JFLa committed on
Commit
f22cc6f
·
verified ·
1 Parent(s): a8f93e1

Upload 14 files

Browse files

Geneformer backbone by Theodoris et al.

geneformer/.DS_Store ADDED
Binary file (6.15 kB). View file
 
geneformer/Dataset_Create.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build train/test subsets of Genecorpus-30M and matching length files for Geneformer pretraining."""

import datetime
import os
import pickle
import random
import subprocess
import sys
import time

# NCCL / MPI environment configuration; set before the heavy third-party
# imports so it is in place before any distributed backend initializes.
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"

from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pytz
import seaborn as sns
import torch
import torch.nn.functional as F
from datasets import Dataset, load_from_disk
from torch import Tensor
# Fix: tqdm.notebook renders broken widget output outside Jupyter;
# tqdm.auto picks the correct frontend for notebooks and plain scripts alike.
from tqdm.auto import tqdm
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertModel,
    BertPreTrainedModel,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.bert.modeling_bert import (
    BertLMPredictionHead,
    BertOnlyMLMHead,
    BertPredictionHeadTransform,
)

from geneformer import GeneformerPretrainer
from geneformer.pretrainer import token_dictionary

# Randomly select `subset_size` sequences from Genecorpus to conduct the
# training; the last `test_size` of the shuffled subset form the test split.
genecorpus = load_from_disk("/ibex/user/chenj0i/Geneformer/Genecorpus-30M/genecorpus_30M_2048.dataset")

subset_size = 1_200_000
test_size = 200_000
train_size = subset_size - test_size

subset_sequences = genecorpus.shuffle(seed=42).select(tqdm(range(subset_size)))['input_ids']
subset_train_dataset = Dataset.from_dict({"input_ids": subset_sequences[:-test_size]})
subset_train_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_1Mtrain_genecorpus.dataset")
subset_test_dataset = Dataset.from_dict({"input_ids": subset_sequences[-test_size:]})
subset_test_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_200K_1Mtrain_genecorpus.dataset")


def _write_length_file(path, num_sequences, sequence_length=2048):
    """Pickle a list of `num_sequences` per-example lengths to `path`.

    NOTE(review): this assumes every example is exactly `sequence_length`
    tokens; if shorter cells exist in the subset, these lengths are only an
    upper bound — confirm against GeneformerPretrainer's expectations.
    """
    with open(path, "wb") as f:
        pickle.dump([sequence_length] * num_sequences, f)
    # Fix: the original reported `subset_size` (1,200,000) for both files,
    # although the train file holds 1,000,000 entries and the test file 200,000.
    print(f"List with {num_sequences} elements saved as {path}")


# Create length files for the training and test splits.
_write_length_file("sub_1Mtrain_genecorpus_30M_2048_lengths.pkl", train_size)
_write_length_file("sub_200K_1Mtrain_genecorpus_30M_2048_lengths.pkl", test_size)
geneformer/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import tokenizer
2
+ from . import pretrainer
3
+ from . import collator_for_classification
4
+ from . import in_silico_perturber
5
+ from . import in_silico_perturber_stats
6
+ from .tokenizer import TranscriptomeTokenizer
7
+ from .pretrainer import GeneformerPretrainer
8
+ from .collator_for_classification import DataCollatorForGeneClassification
9
+ from .collator_for_classification import DataCollatorForCellClassification
10
+ from .emb_extractor import EmbExtractor
11
+ from .in_silico_perturber import InSilicoPerturber
12
+ from .in_silico_perturber_stats import InSilicoPerturberStats
geneformer/collator_for_classification.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer collator for gene and cell classification.
3
+
4
+ Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import warnings
9
+ from enum import Enum
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from transformers import (
13
+ DataCollatorForTokenClassification,
14
+ SpecialTokensMixin,
15
+ BatchEncoding,
16
+ )
17
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
+ from transformers.utils.generic import _is_tensorflow, _is_torch
19
+
20
+ from .pretrainer import token_dictionary
21
+
22
# Type alias: a single tokenized sequence, represented as a list of token ids.
EncodedInput = List[int]
logger = logging.get_logger(__name__)
VERY_LARGE_INTEGER = int(
    1e30
)  # This is used to set the max input length for a model with infinite size input
LARGE_INTEGER = int(
    1e20
)  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER

# precollator functions
32
+
33
class ExplicitEnum(Enum):
    """Enum whose failed value lookups raise a descriptive ``ValueError``.

    Instead of the default terse enum error, the message lists the valid
    member values so callers can immediately see what to pass.
    """

    @classmethod
    def _missing_(cls, value):
        # Build the message in steps for readability; the format string and
        # resulting text are identical to the upstream HuggingFace wording.
        valid_values = str(list(cls._value2member_map_.keys()))
        message = "%r is not a valid %s, please select one of %s" % (
            value,
            cls.__name__,
            valid_values,
        )
        raise ValueError(message)
44
+
45
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    ONLY_FIRST = "only_first"  # truncate only the first sequence of a pair
    ONLY_SECOND = "only_second"  # truncate only the second sequence of a pair
    LONGEST_FIRST = "longest_first"  # iteratively trim whichever sequence is currently longer
    DO_NOT_TRUNCATE = "do_not_truncate"  # leave sequences untouched
55
+
56
+
57
+
58
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"  # pad to the longest sequence in the batch
    MAX_LENGTH = "max_length"  # pad to `max_length` (or the model max length)
    DO_NOT_PAD = "do_not_pad"  # no padding applied
67
+
68
+
69
+
70
class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"  # torch.Tensor
    TENSORFLOW = "tf"  # tf.constant
    NUMPY = "np"  # np.ndarray
    JAX = "jax"  # jax arrays
80
+
81
+
82
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
    """
    Minimal tokenizer-like object used by the classification collators below.

    Padding logic adapted from the HuggingFace ``PreTrainedTokenizerBase``,
    modified for single-cell transcriptomics data: ``pad``/``_pad`` take an
    extra ``class_type`` argument ("gene" or "cell"). Gene-level labels are
    padded with -100 alongside the inputs; cell-level labels are dropped here
    and re-attached downstream by ``DataCollatorForCellClassification``.
    """

    # Special tokens and their ids come from the project-level token_dictionary.
    # NOTE(review): if "<mask>"/"<pad>" are missing from the dictionary these
    # ids are None — confirm the dictionary always defines them.
    mask_token = "<mask>"
    mask_token_id = token_dictionary.get("<mask>")
    pad_token = "<pad>"
    pad_token_id = token_dictionary.get("<pad>")
    padding_side = "right"
    all_special_ids = [
        token_dictionary.get("<mask>"),
        token_dictionary.get("<pad>")
    ]
    model_input_names = ["input_ids"]

    def _get_padding_truncation_strategies(
        self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
    ):
        """
        Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
        and pad_to_max_length) and behaviors.
        """
        old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
        old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

        # Backward compatibility for previous behavior, maybe we should deprecate it:
        # If you only set max_length, it activates truncation for max_length
        if max_length is not None and padding is False and truncation is False:
            if verbose:
                if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
                    logger.warning(
                        "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                        "please use `truncation=True` to explicitly truncate examples to max length. "
                        "Defaulting to 'longest_first' truncation strategy. "
                        "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                        "more precisely by providing a specific strategy to `truncation`."
                    )
                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
            truncation = "longest_first"

        # Get padding strategy
        if padding is False and old_pad_to_max_length:
            if verbose:
                warnings.warn(
                    "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                    "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                    "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                    "maximal input size of the model (e.g. 512 for Bert).",
                    FutureWarning,
                )
            if max_length is None:
                padding_strategy = PaddingStrategy.LONGEST
            else:
                padding_strategy = PaddingStrategy.MAX_LENGTH
        elif padding is not False:
            if padding is True:
                padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
            elif not isinstance(padding, PaddingStrategy):
                padding_strategy = PaddingStrategy(padding)
            elif isinstance(padding, PaddingStrategy):
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Get truncation strategy
        if truncation is False and old_truncation_strategy != "do_not_truncate":
            if verbose:
                warnings.warn(
                    "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                    "use `truncation=True` to truncate examples to a max length. You can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                    "maximal input size of the model (e.g. 512 for Bert). "
                    " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                    "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                    "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                    "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                    FutureWarning,
                )
            truncation_strategy = TruncationStrategy(old_truncation_strategy)
        elif truncation is not False:
            if truncation is True:
                truncation_strategy = (
                    TruncationStrategy.LONGEST_FIRST
                )  # Default to truncate the longest sequences in pairs of inputs
            elif not isinstance(truncation, TruncationStrategy):
                truncation_strategy = TruncationStrategy(truncation)
            elif isinstance(truncation, TruncationStrategy):
                truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
                            logger.warning(
                                "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no padding."
                            )
                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                    padding_strategy = PaddingStrategy.DO_NOT_PAD
                else:
                    max_length = self.model_max_length

            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
                            logger.warning(
                                "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no truncation."
                            )
                        self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
                else:
                    max_length = self.model_max_length

        # Test if we have a padding token
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
            raise ValueError(
                "Asking to pad but the tokenizer does not have a padding token. "
                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
            )

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
            and padding_strategy != PaddingStrategy.DO_NOT_PAD
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                f"Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        return padding_strategy, truncation_strategy, max_length, kwargs

    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        class_type,  # options: "gene" or "cell"
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
        in the batch.

        Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
        ``self.pad_token_id`` and ``self.pad_token_type_id``)

        .. note::

            If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
            result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
            case of PyTorch tensors, you will lose the specific device of your tensors however.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
                well as in a PyTorch Dataloader collate function.

                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            class_type:
                Either ``"gene"`` or ``"cell"``; controls label handling (see class docstring).
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence if provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask (:obj:`bool`, `optional`):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.

                `What are attention masks? <../glossary.html#attention-mask>`__
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to print more information and warnings.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

        # The model's main input name, usually `input_ids`, has to be passed for padding
        if self.model_input_names[0] not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method"
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[self.model_input_names[0]]

        if not required_input:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch

        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
            # NOTE(review): inherited from upstream — this while loop can raise
            # IndexError if every sequence in the batch is empty; confirm that
            # case cannot occur with tokenized cell data.
            index = 0
            while len(required_input[index]) == 0:
                index += 1
            if index < len(required_input):
                first_element = required_input[index][0]
        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
        if not isinstance(first_element, (int, list, tuple)):
            if is_tf_available() and _is_tensorflow(first_element):
                return_tensors = "tf" if return_tensors is None else return_tensors
            elif is_torch_available() and _is_torch(first_element):
                return_tensors = "pt" if return_tensors is None else return_tensors
            elif isinstance(first_element, np.ndarray):
                return_tensors = "np" if return_tensors is None else return_tensors
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    f"Should be one of a python, numpy, pytorch or tensorflow object."
                )

            for key, value in encoded_inputs.items():
                encoded_inputs[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
        )

        required_input = encoded_inputs[self.model_input_names[0]]
        # Single (un-batched) input: pad it directly and return.
        if required_input and not isinstance(required_input[0], (list, tuple)):
            encoded_inputs = self._pad(
                encoded_inputs,
                class_type=class_type,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        # Pad each example independently, then regroup per key.
        batch_outputs = {}
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
            outputs = self._pad(
                inputs,
                class_type=class_type,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)
        # Cell-level labels are not padded here; DataCollatorForCellClassification
        # re-attaches them from the raw features after padding.
        if class_type == "cell":
            del batch_outputs["label"]
        return BatchEncoding(batch_outputs, tensor_type=return_tensors)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        class_type,  # options: "gene" or "cell"
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        required_input = encoded_inputs[self.model_input_names[0]]

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # Round max_length up to the next multiple of pad_to_multiple_of.
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
                # Gene-level labels are padded with -100 so loss functions ignore pad positions.
                if class_type == "gene":
                    encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
            elif self.padding_side == "left":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                        "token_type_ids"
                    ]
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
                if class_type == "gene":
                    encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
        elif return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * len(required_input)

        return encoded_inputs

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`):
                List of ids of the second sequence.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        assert already_has_special_tokens and token_ids_1 is None, (
            "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
            "Please use a slow (full python) tokenizer to activate this argument."
            "Or set `return_special_tokens_mask=True` when calling the encoding method "
            "to get the special tokens mask in any tokenizer. "
        )

        all_special_ids = self.all_special_ids  # cache the property

        special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]

        return special_tokens_mask

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
        vocabulary.
        Args:
            tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
        Returns:
            :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        # Returns None both for a None token and for tokens absent from the
        # dictionary (dict.get default).
        if token is None:
            return None

        return token_dictionary.get(token)

    def __len__(self):
        # Vocabulary size as seen by models built on this tokenizer.
        return len(token_dictionary)
520
+
521
+
522
# collator functions

class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
    """

    tokenizer = PrecollatorForGeneAndCellClassification()
    class_type = "gene"
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __init__(self, *args, **kwargs) -> None:
        # Forward the class-level defaults to the HF parent collator so the
        # custom precollator above is always used for padding.
        super().__init__(
            tokenizer=self.tokenizer,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            label_pad_token_id=self.label_pad_token_id,
            *args, **kwargs)

    def _prepare_batch(self, features):
        """Pad a list of feature dicts into one batch of PyTorch tensors."""
        # Fix: removed the unused `label_name`/`labels` locals that were
        # computed here and never used. Label handling happens inside
        # `self.tokenizer.pad` (gene labels, padded with -100) or in
        # DataCollatorForCellClassification (cell labels).
        batch = self.tokenizer.pad(
            features,
            class_type=self.class_type,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        return batch

    def __call__(self, features):
        batch = self._prepare_batch(features)

        # Fix: the batch values are already torch tensors (return_tensors="pt"),
        # so torch.tensor(v, ...) emitted a copy-construct UserWarning and
        # forced a copy; torch.as_tensor coerces dtype without the warning and
        # avoids the copy when the dtype already matches.
        batch = {k: torch.as_tensor(v, dtype=torch.int64) for k, v in batch.items()}
        return batch
583
+
584
+
585
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
    """Collator for cell-level classification: pads inputs via the parent,
    then attaches one label per cell from the raw features."""

    class_type = "cell"

    def _prepare_batch(self, features):
        """Pad the batch, then rebuild ``batch["labels"]`` from the features.

        The label dtype is inferred from the first example: integer labels
        become ``torch.long``, anything else becomes ``torch.float``.
        """
        batch = super()._prepare_batch(features)

        sample = features[0]
        if "label" not in sample or sample["label"] is None:
            return batch

        # Probe the first label to choose the tensor dtype; unwrap a 0-dim
        # tensor so the isinstance(int) check sees a plain Python value.
        probe = sample["label"]
        if isinstance(probe, torch.Tensor):
            probe = probe.item()
        label_dtype = torch.long if isinstance(probe, int) else torch.float

        batch["labels"] = torch.tensor(
            [example["label"] for example in features], dtype=label_dtype
        )
        return batch
geneformer/emb_extractor.py ADDED
@@ -0,0 +1,806 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer embedding extractor.
3
+
4
+ **Description:**
5
+
6
+ | Extracts gene or cell embeddings.
7
+ | Plots cell embeddings as heatmaps or UMAPs.
8
+ | Generates cell state embedding dictionary for use with InSilicoPerturber.
9
+
10
+ """
11
+
12
+ # imports
13
+ import logging
14
+ import pickle
15
+ from collections import Counter
16
+ from pathlib import Path
17
+
18
+ import anndata
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+ import pandas as pd
22
+ import scanpy as sc
23
+ import seaborn as sns
24
+ import torch
25
+ from tdigest import TDigest
26
+ from tqdm.auto import trange
27
+
28
+ from . import perturber_utils as pu
29
+ from .tokenizer import TOKEN_DICTIONARY_FILE
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
# extract embeddings
def get_embs(
    model,
    filtered_input_data,
    emb_mode,
    layer_to_quant,
    pad_token_id,
    forward_batch_size,
    summary_stat=None,
    silent=False,
):
    """Run the model over a tokenized dataset and collect embeddings.

    emb_mode "cell": mean-pools gene embeddings per cell.
    emb_mode "gene": keeps per-gene contextual embeddings.
    If summary_stat is "mean"/"median", embeddings are folded into TDigest
    sketches instead of being stacked, trading exactness for memory.
    Requires a CUDA device (inputs are moved with .to("cuda")).
    """
    model_input_size = pu.get_model_input_size(model)
    total_batch_length = len(filtered_input_data)

    if summary_stat is None:
        embs_list = []
    elif summary_stat is not None:
        # test embedding extraction for example cell and extract # emb dims
        example = filtered_input_data.select([i for i in range(1)])
        example.set_format(type="torch")
        emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
        if emb_mode == "cell":
            # initiate tdigests for # of emb dims
            embs_tdigests = [TDigest() for _ in range(emb_dims)]
        if emb_mode == "gene":
            gene_set = list(
                {
                    element
                    for sublist in filtered_input_data["input_ids"]
                    for element in sublist
                }
            )
            # initiate dict with genes as keys and tdigests for # of emb dims as values
            embs_tdigests_dict = {
                k: [TDigest() for _ in range(emb_dims)] for k in gene_set
            }

    overall_max_len = 0

    for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
        max_range = min(i + forward_batch_size, total_batch_length)

        minibatch = filtered_input_data.select([i for i in range(i, max_range)])

        # pad the minibatch only to its own longest sequence
        max_len = int(max(minibatch["length"]))
        original_lens = torch.tensor(minibatch["length"], device="cuda")
        minibatch.set_format(type="torch")

        input_data_minibatch = minibatch["input_ids"]
        input_data_minibatch = pu.pad_tensor_list(
            input_data_minibatch, max_len, pad_token_id, model_input_size
        )

        with torch.no_grad():
            outputs = model(
                input_ids=input_data_minibatch.to("cuda"),
                attention_mask=pu.gen_attention_mask(minibatch),
            )

        # hidden_states[layer_to_quant]: (batch, seq_len, emb_dim)
        embs_i = outputs.hidden_states[layer_to_quant]

        if emb_mode == "cell":
            # mean over non-padding positions -> one embedding per cell
            mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
            if summary_stat is None:
                embs_list.append(mean_embs)
            elif summary_stat is not None:
                # update tdigests with current batch for each emb dim
                accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
                del mean_embs
        elif emb_mode == "gene":
            if summary_stat is None:
                embs_list.append(embs_i)
            elif summary_stat is not None:
                for h in trange(len(minibatch)):
                    length_h = minibatch[h]["length"]
                    input_ids_h = minibatch[h]["input_ids"][0:length_h]

                    # double check dimensions before unsqueezing
                    embs_i_dim = embs_i.dim()
                    if embs_i_dim != 3:
                        logger.error(
                            f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
                        )
                        # NOTE(review): bare `raise` outside an except block
                        # raises RuntimeError, not a descriptive error —
                        # consider raising ValueError explicitly.
                        raise

                    embs_h = embs_i[h, :, :].unsqueeze(dim=1)
                    dict_h = dict(zip(input_ids_h, embs_h))
                    for k in dict_h.keys():
                        accumulate_tdigests(
                            embs_tdigests_dict[int(k)], dict_h[k], emb_dims
                        )

        overall_max_len = max(overall_max_len, max_len)
        # free GPU memory before the next minibatch
        del outputs
        del minibatch
        del input_data_minibatch
        del embs_i

        torch.cuda.empty_cache()

    if summary_stat is None:
        if emb_mode == "cell":
            embs_stack = torch.cat(embs_list, dim=0)
        elif emb_mode == "gene":
            # re-pad every per-batch tensor to the overall longest sequence
            embs_stack = pu.pad_tensor_list(
                embs_list,
                overall_max_len,
                pad_token_id,
                model_input_size,
                1,
                pu.pad_3d_tensor,
            )

    # calculate summary stat embs from approximated tdigests
    elif summary_stat is not None:
        if emb_mode == "cell":
            if summary_stat == "mean":
                summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
            elif summary_stat == "median":
                summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
            embs_stack = torch.tensor(summary_emb_list)
        elif emb_mode == "gene":
            if summary_stat == "mean":
                [
                    update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
                    for gene in embs_tdigests_dict.keys()
                ]
            elif summary_stat == "median":
                [
                    update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
                    for gene in embs_tdigests_dict.keys()
                ]
            return embs_tdigests_dict

    return embs_stack
169
+
170
+
171
def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
    """Serially fold a batch of embeddings into per-dimension TDigests.

    Parameters
    ----------
    embs_tdigests : list
        One TDigest per embedding dimension.
    mean_embs : torch.Tensor
        (batch, emb_dims) tensor of embeddings to accumulate.
    emb_dims : int
        Number of embedding dimensions to update.

    Returns
    -------
    list
        The (mutated) ``embs_tdigests`` list, so callers that rebind the
        return value (e.g. ``update_tdigest_dict``) keep a valid list
        instead of ``None``.
    """
    # note: tdigest batch update known to be slow so updating serially;
    # a plain loop replaces the original side-effect-only comprehension.
    for row in range(mean_embs.size(0)):
        for dim in range(emb_dims):
            embs_tdigests[dim].update(mean_embs[row, dim].item())
    return embs_tdigests
178
+
179
+
180
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
    """Fold one gene's contextual embeddings into its TDigest accumulators.

    Bug fix: the original rebound ``embs_tdigests_dict[gene]`` to the
    return value of ``accumulate_tdigests``, which was ``None`` — wiping
    out the digest list on the first update. The digests are mutated in
    place, so no reassignment is needed.
    """
    accumulate_tdigests(embs_tdigests_dict[gene], gene_embs, emb_dims)
184
+
185
+
186
def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
    """Replace a gene's digest list with its per-dimension trimmed means."""
    summarized = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
    embs_tdigests_dict[gene] = summarized
188
+
189
+
190
def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
    """Replace a gene's digest list with its per-dimension medians."""
    summarized = tdigest_median(embs_tdigests_dict[gene], emb_dims)
    embs_tdigests_dict[gene] = summarized
192
+
193
+
194
def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
    """Accumulate the h-th cell's per-gene embeddings into the digest dict.

    Parameters
    ----------
    h : int
        Index of the cell within the minibatch.
    minibatch : dataset slice
        Provides ``length`` and ``input_ids`` for cell ``h``.
    embs_i : torch.Tensor
        (batch, seq_len, emb_dim) embeddings for the minibatch.
    embs_tdigests_dict : dict
        token id -> list of TDigests, updated in place.
    emb_dims : int
        Number of embedding dimensions.
    """
    length_h = minibatch[h]["length"]
    input_ids_h = minibatch[h]["input_ids"][0:length_h]
    # (seq_len, 1, emb_dim) so each gene's embedding keeps a batch axis
    embs_h = embs_i[h, :, :].unsqueeze(dim=1)
    gene_to_emb = dict(zip(input_ids_h, embs_h))
    # Plain loop instead of the original side-effect-only list comprehension,
    # which allocated a throwaway list of Nones.
    for token in gene_to_emb:
        update_tdigest_dict(embs_tdigests_dict, token, gene_to_emb[token], emb_dims)
203
+
204
+
205
def tdigest_mean(embs_tdigests, emb_dims):
    """Per-dimension trimmed mean (0th–100th percentile) of the digests."""
    means = []
    for dim in range(emb_dims):
        means.append(embs_tdigests[dim].trimmed_mean(0, 100))
    return means
207
+
208
+
209
def tdigest_median(embs_tdigests, emb_dims):
    """Per-dimension median (50th percentile) of the digests."""
    medians = []
    for dim in range(emb_dims):
        medians.append(embs_tdigests[dim].percentile(50))
    return medians
211
+
212
+
213
def test_emb(model, example, layer_to_quant):
    """Run one trial forward pass and return the embedding dimensionality.

    Moves ``example`` to CUDA; ``hidden_states[layer_to_quant]`` has shape
    (batch, seq_len, emb_dim) and the last axis size is returned.
    """
    with torch.no_grad():
        trial_outputs = model(input_ids=example.to("cuda"))

    return trial_outputs.hidden_states[layer_to_quant].size()[2]
219
+
220
+
221
def label_cell_embs(embs, downsampled_data, emb_labels):
    """Build a DataFrame of cell embeddings, appending any label columns.

    Embedding dimensions become the leading numeric columns; each entry of
    ``emb_labels`` is copied from ``downsampled_data`` as a trailing column.
    """
    embs_df = pd.DataFrame(embs.cpu().numpy())
    if emb_labels is None:
        return embs_df
    for label_name in emb_labels:
        embs_df[label_name] = downsampled_data[label_name]
    return embs_df
228
+
229
+
230
def label_gene_embs(embs, downsampled_data, token_gene_dict):
    """Average each gene's contextual embeddings across all cells.

    Returns a DataFrame with one row per gene (row index mapped through
    ``token_gene_dict``) and one column per embedding dimension.
    """
    all_tokens = {
        token for ids in downsampled_data["input_ids"] for token in ids
    }
    per_gene_embs = {token: [] for token in all_tokens}
    for cell_idx in range(embs.size()[0]):
        n_genes = downsampled_data[cell_idx]["length"]
        tokens = downsampled_data[cell_idx]["input_ids"][0:n_genes]
        # (seq_len, 1, emb_dim): keep a singleton batch axis per gene
        cell_embs = embs[cell_idx, :, :].unsqueeze(dim=1)
        # dict() dedups repeated tokens within a cell (last occurrence wins),
        # matching the original behavior.
        for token, emb in dict(zip(tokens, cell_embs)).items():
            per_gene_embs[token].append(emb)
    for token in per_gene_embs.keys():
        stacked = torch.stack(per_gene_embs[token])
        per_gene_embs[token] = (
            torch.squeeze(torch.mean(stacked, dim=0), dim=0).cpu().numpy()
        )
    embs_df = pd.DataFrame(per_gene_embs).T
    embs_df.index = [token_gene_dict[token] for token in embs_df.index]
    return embs_df
254
+
255
+
256
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """Compute PCA -> neighbors -> UMAP on the embedding columns and plot.

    ``embs_df`` carries the embedding dimensions in its first ``emb_dims``
    columns and label columns after them; cells are colored by ``label``.
    """
    # Keep only the embedding columns and give them string axis labels,
    # as required by AnnData.
    emb_matrix = embs_df.iloc[:, :emb_dims]
    emb_matrix.index = pd.RangeIndex(0, emb_matrix.shape[0], name=None).astype(str)
    emb_matrix.columns = pd.RangeIndex(0, emb_matrix.shape[1], name=None).astype(
        str
    )
    adata = anndata.AnnData(
        X=emb_matrix,
        obs={"cell_id": list(emb_matrix.index), f"{label}": list(embs_df[label])},
        var={"embs": emb_matrix.columns},
    )
    sc.tl.pca(adata, svd_solver="arpack")
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
    sns.set_style("white")
    plot_kwargs = {"palette": "Set2", "size": 200}
    if kwargs_dict is not None:
        plot_kwargs.update(kwargs_dict)

    sc.pl.umap(adata, color=label, save=output_file, **plot_kwargs)
275
+
276
+
277
def gen_heatmap_class_colors(labels, df):
    """Assign one cubehelix color per distinct label, indexed like df rows."""
    distinct_labels = Counter(labels).keys()
    palette = sns.cubehelix_palette(
        len(distinct_labels),
        light=0.9,
        dark=0.1,
        hue=1,
        reverse=True,
        start=1,
        rot=-2,
    )
    # Look-up table keyed by the stringified label.
    lut = dict(zip(map(str, distinct_labels), palette))
    return pd.Series(labels, index=df.index).map(lut)
290
+
291
+
292
def gen_heatmap_class_dict(classes, label_colors_series):
    """Return {class: color}, keeping the first color seen per class."""
    pairs = pd.DataFrame({"classes": classes, "color": label_colors_series})
    unique_pairs = pairs.drop_duplicates(subset=["classes"])
    return dict(zip(unique_pairs["classes"], unique_pairs["color"]))
298
+
299
+
300
def make_colorbar(embs_df, label):
    """Build row-color annotations and a class->color dict for a heatmap.

    Returns
    -------
    (pandas.DataFrame, dict)
        A one-column DataFrame of per-row colors (column named ``label``)
        and a mapping of each class value to its color.
    """
    labels = list(embs_df[label])

    cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
    label_colors = pd.DataFrame(cell_type_colors, columns=[label])

    # Diagnostic: report any row whose color is malformed (wrong arity or NaN).
    for idx, row in label_colors.iterrows():
        color = row[0]
        if len(color) != 3 or any(np.isnan(color)):
            print(idx, color)

    # Fix: removed the original no-op expression ``label_colors.isna().sum()``
    # whose result was discarded.

    # create dictionary for colors and classes
    label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
    return label_colors, label_color_dict
316
+
317
+
318
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """Cluster and plot the embedding columns as a heatmap colored by label."""
    sns.set_style("white")
    sns.set(font_scale=2)
    plt.figure(figsize=(15, 15), dpi=150)
    label_colors, label_color_dict = make_colorbar(embs_df, label)

    heatmap_kwargs = {
        "row_cluster": True,
        "col_cluster": True,
        "row_colors": label_colors,
        "standard_scale": 1,
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "figsize": (15, 15),
        "center": 0,
        "cmap": "magma",
    }

    # Caller-supplied kwargs override the defaults above.
    if kwargs_dict is not None:
        heatmap_kwargs.update(kwargs_dict)
    grid = sns.clustermap(
        embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **heatmap_kwargs
    )

    plt.setp(grid.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # Invisible zero-height bars register one legend entry per class.
    for class_name in list(label_color_dict.keys()):
        grid.ax_col_dendrogram.bar(
            0, 0, color=label_color_dict[class_name], label=class_name, linewidth=0
        )

    grid.ax_col_dendrogram.legend(
        title=f"{label}",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, 1),
        facecolor="white",
    )

    plt.savefig(output_file, bbox_inches="tight")
359
+
360
+
361
class EmbExtractor:
    """Extracts cell or gene embeddings from a (fine-tuned) Geneformer model."""

    # Allowed values (or allowed types) for each constructor argument;
    # enforced by validate_options().
    valid_option_dict = {
        "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
        "num_classes": {int},
        "emb_mode": {"cell", "gene"},
        "cell_emb_style": {"mean_pool"},
        "gene_emb_style": {"mean_pool"},
        "filter_data": {None, dict},
        "max_ncells": {None, int},
        "emb_layer": {-1, 0},
        "emb_label": {None, list},
        "labels_to_plot": {None, list},
        "forward_batch_size": {int},
        "nproc": {int},
        "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
    }

    def __init__(
        self,
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cell",
        cell_emb_style="mean_pool",
        gene_emb_style="mean_pool",
        filter_data=None,
        max_ncells=1000,
        emb_layer=-1,
        emb_label=None,
        labels_to_plot=None,
        forward_batch_size=100,
        nproc=4,
        summary_stat=None,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize embedding extractor.

        **Parameters:**

        model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
            | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
        num_classes : int
            | If model is a gene or cell classifier, specify number of classes it was trained to classify.
            | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
        emb_mode : {"cell", "gene"}
            | Whether to output cell or gene embeddings.
        cell_emb_style : "mean_pool"
            | Method for summarizing cell embeddings.
            | Currently only option is mean pooling of gene embeddings for given cell.
        gene_emb_style : "mean_pool"
            | Method for summarizing gene embeddings.
            | Currently only option is mean pooling of contextual gene embeddings for given gene.
        filter_data : None, dict
            | Default is to extract embeddings from all input data.
            | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        max_ncells : None, int
            | Maximum number of cells to extract embeddings from.
            | Default is 1000 cells randomly sampled from input data.
            | If None, will extract embeddings from all cells.
        emb_layer : {-1, 0}
            | Embedding layer to extract.
            | The last layer is most specifically weighted to optimize the given learning objective.
            | Generally, it is best to extract the 2nd to last layer to get a more general representation.
            | -1: 2nd to last layer
            | 0: last layer
        emb_label : None, list
            | List of column name(s) in .dataset to add as labels to embedding output.
        labels_to_plot : None, list
            | Cell labels to plot.
            | Shown as color bar in heatmap.
            | Shown as cell color in umap.
            | Plotting umap requires labels to plot.
        forward_batch_size : int
            | Batch size for forward pass.
        nproc : int
            | Number of CPU processes to use.
        summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
            | If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
            | If mean or median, outputs only approximated mean or median embedding of input data.
            | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
            | Non-exact is slower but more memory-efficient.
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl ID:token).

        **Examples:**

        .. code-block :: python

            >>> from geneformer import EmbExtractor
            >>> embex = EmbExtractor(model_type="CellClassifier",
            ...                      num_classes=3,
            ...                      emb_mode="cell",
            ...                      filter_data={"cell_type":["cardiomyocyte"]},
            ...                      max_ncells=1000,
            ...                      max_ncells_to_plot=1000,
            ...                      emb_layer=-1,
            ...                      emb_label=["disease", "cell_type"],
            ...                      labels_to_plot=["disease", "cell_type"])

        """

        self.model_type = model_type
        self.num_classes = num_classes
        self.emb_mode = emb_mode
        self.cell_emb_style = cell_emb_style
        self.gene_emb_style = gene_emb_style
        self.filter_data = filter_data
        self.max_ncells = max_ncells
        self.emb_layer = emb_layer
        self.emb_label = emb_label
        self.labels_to_plot = labels_to_plot
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc
        # "exact_*" stats are computed on the full tensor after extraction,
        # so they are stored separately and summary_stat is cleared.
        if (summary_stat is not None) and ("exact" in summary_stat):
            self.summary_stat = None
            self.exact_summary_stat = summary_stat
        else:
            self.summary_stat = summary_stat
            self.exact_summary_stat = None

        self.validate_options()

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
        self.pad_token_id = self.gene_token_dict.get("<pad>")

    def validate_options(self):
        # confirm arguments are within valid options and compatible with each other
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]
            # lists/dicts are unhashable, so membership is only tested for
            # scalar values; containers fall through to the type check below.
            if not isinstance(attr_value, (list, dict)):
                if attr_value in valid_options:
                    continue
            valid_type = False
            for option in valid_options:
                if (option in [int, list, dict, bool]) and isinstance(
                    attr_value, option
                ):
                    valid_type = True
                    break
            if valid_type:
                continue
            logger.error(
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            # NOTE(review): bare `raise` outside an except block raises
            # RuntimeError rather than a descriptive exception.
            raise

        if self.filter_data is not None:
            for key, value in self.filter_data.items():
                if not isinstance(value, list):
                    self.filter_data[key] = [value]
                    logger.warning(
                        "Values in filter_data dict must be lists. "
                        f"Changing {key} value to list ([{value}])."
                    )

    def extract_embs(
        self,
        model_directory,
        input_data_file,
        output_directory,
        output_prefix,
        output_torch_embs=False,
        cell_state=None,
    ):
        """
        Extract embeddings from input data and save as results in output_directory.

        **Parameters:**

        model_directory : Path
            | Path to directory containing model
        input_data_file : Path
            | Path to directory containing .dataset inputs
        output_directory : Path
            | Path to directory where embedding data will be saved as csv
        output_prefix : str
            | Prefix for output file
        output_torch_embs : bool
            | Whether or not to also output the embeddings as a tensor.
            | Note, if true, will output embeddings as both dataframe and tensor.
        cell_state : dict
            | Cell state key and value for state embedding extraction.

        **Examples:**

        .. code-block :: python

            >>> embs = embex.extract_embs("path/to/model",
            ...                           "path/to/input_data",
            ...                           "path/to/output_directory",
            ...                           "output_prefix")

        """

        filtered_input_data = pu.load_and_filter(
            self.filter_data, self.nproc, input_data_file
        )
        if cell_state is not None:
            filtered_input_data = pu.filter_by_dict(
                filtered_input_data, cell_state, self.nproc
            )
        downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
        model = pu.load_model(self.model_type, self.num_classes, model_directory)
        # emb_layer is -1 (2nd to last) or 0 (last), offset from layer count
        layer_to_quant = pu.quant_layers(model) + self.emb_layer
        embs = get_embs(
            model,
            downsampled_data,
            self.emb_mode,
            layer_to_quant,
            self.pad_token_id,
            self.forward_batch_size,
            self.summary_stat,
        )

        if self.emb_mode == "cell":
            if self.summary_stat is None:
                embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
            elif self.summary_stat is not None:
                embs_df = pd.DataFrame(embs.cpu().numpy()).T
        elif self.emb_mode == "gene":
            if self.summary_stat is None:
                embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
            elif self.summary_stat is not None:
                embs_df = pd.DataFrame(embs).T
                embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]

        # save embeddings to output_path
        if cell_state is None:
            output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
            embs_df.to_csv(output_path)

        # NOTE(review): the [0:255] slice caps the exact summary at the first
        # 255 rows — presumably it should cover all rows; confirm intent.
        if self.exact_summary_stat == "exact_mean":
            embs = embs.mean(dim=0)
            embs_df = pd.DataFrame(
                embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
            ).T
        elif self.exact_summary_stat == "exact_median":
            embs = torch.median(embs, dim=0)[0]
            embs_df = pd.DataFrame(
                embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
            ).T

        if cell_state is not None:
            return embs
        else:
            if output_torch_embs:
                return embs_df, embs
            else:
                return embs_df

    def get_state_embs(
        self,
        cell_states_to_model,
        model_directory,
        input_data_file,
        output_directory,
        output_prefix,
        output_torch_embs=True,
    ):
        """
        Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.

        **Parameters:**

        cell_states_to_model : None, dict
            | Cell states to model if testing perturbations that achieve goal state change.
            | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
            | state_key: key specifying name of column in .dataset that defines the start/goal states
            | start_state: value in the state_key column that specifies the start state
            | goal_state: value in the state_key column that specifies the goal end state
            | alt_states: list of values in the state_key column that specify the alternate end states
            | For example:
            | {"state_key": "disease",
            |  "start_state": "dcm",
            |  "goal_state": "nf",
            |  "alt_states": ["hcm", "other1", "other2"]}
        model_directory : Path
            | Path to directory containing model
        input_data_file : Path
            | Path to directory containing .dataset inputs
        output_directory : Path
            | Path to directory where embedding data will be saved as csv
        output_prefix : str
            | Prefix for output file
        output_torch_embs : bool
            | Whether or not to also output the embeddings as a tensor.
            | Note, if true, will output embeddings as both dataframe and tensor.

        **Outputs**

        | Outputs state_embs_dict for use with in silico perturber.
        | Format is dictionary of embedding positions of each cell state to model shifts from/towards.
        | Keys specify each possible cell state to model.
        | Values are target embedding positions as torch.tensor.
        | For example:
        | {"nf": emb_nf,
        |  "hcm": emb_hcm,
        |  "dcm": emb_dcm,
        |  "other1": emb_other1,
        |  "other2": emb_other2}
        """

        pu.validate_cell_states_to_model(cell_states_to_model)
        valid_summary_stats = ["exact_mean", "exact_median"]
        if self.exact_summary_stat not in valid_summary_stats:
            logger.error(
                "For extracting state embs, summary_stat in EmbExtractor "
                f"must be set to option in {valid_summary_stats}"
            )
            # NOTE(review): bare `raise` — see validate_options.
            raise

        state_embs_dict = dict()
        state_key = cell_states_to_model["state_key"]
        # One embedding extraction per modeled state (start, goal, each alt).
        for k, v in cell_states_to_model.items():
            if k == "state_key":
                continue
            elif (k == "start_state") or (k == "goal_state"):
                state_embs_dict[v] = self.extract_embs(
                    model_directory,
                    input_data_file,
                    output_directory,
                    output_prefix,
                    output_torch_embs,
                    cell_state={state_key: v},
                )
            else:  # k == "alt_states"
                for alt_state in v:
                    state_embs_dict[alt_state] = self.extract_embs(
                        model_directory,
                        input_data_file,
                        output_directory,
                        output_prefix,
                        output_torch_embs,
                        cell_state={state_key: alt_state},
                    )

        output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
        with open(output_path, "wb") as fp:
            pickle.dump(state_embs_dict, fp)

        return state_embs_dict

    def plot_embs(
        self,
        embs,
        plot_style,
        output_directory,
        output_prefix,
        max_ncells_to_plot=1000,
        kwargs_dict=None,
    ):
        """
        Plot embeddings, coloring by provided labels.

        **Parameters:**

        embs : pandas.core.frame.DataFrame
            | Pandas dataframe containing embeddings output from extract_embs
        plot_style : str
            | Style of plot: "heatmap" or "umap"
        output_directory : Path
            | Path to directory where plots will be saved as pdf
        output_prefix : str
            | Prefix for output file
        max_ncells_to_plot : None, int
            | Maximum number of cells to plot.
            | Default is 1000 cells randomly sampled from embeddings.
            | If None, will plot embeddings from all cells.
        kwargs_dict : dict
            | Dictionary of kwargs to pass to plotting function.

        **Examples:**

        .. code-block :: python

            >>> embex.plot_embs(embs=embs,
            ...                 plot_style="heatmap",
            ...                 output_directory="path/to/output_directory",
            ...                 output_prefix="output_prefix")

        """

        if plot_style not in ["heatmap", "umap"]:
            logger.error(
                "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
            )
            raise

        if (plot_style == "umap") and (self.labels_to_plot is None):
            logger.error("Plotting UMAP requires 'labels_to_plot'. ")
            raise

        # NOTE(review): this comparison raises TypeError when either
        # max_ncells_to_plot or self.max_ncells is None — guard for None
        # should come first; confirm intended behavior.
        if max_ncells_to_plot > self.max_ncells:
            max_ncells_to_plot = self.max_ncells
            logger.warning(
                "max_ncells_to_plot must be <= max_ncells. "
                f"Changing max_ncells_to_plot to {self.max_ncells}."
            )

        if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
            embs = embs.sample(max_ncells_to_plot, axis=0)

        if self.emb_label is None:
            label_len = 0
        else:
            label_len = len(self.emb_label)

        # label columns sit after the embedding-dimension columns
        emb_dims = embs.shape[1] - label_len

        if self.emb_label is None:
            emb_labels = None
        else:
            emb_labels = embs.columns[emb_dims:]

        if plot_style == "umap":
            for label in self.labels_to_plot:
                if label not in emb_labels:
                    logger.warning(
                        f"Label {label} from labels_to_plot "
                        f"not present in provided embeddings dataframe."
                    )
                    continue
                output_prefix_label = "_" + output_prefix + f"_umap_{label}"
                output_file = (
                    Path(output_directory) / output_prefix_label
                ).with_suffix(".pdf")
                # NOTE(review): plot_umap is given output_prefix_label while
                # output_file is unused — scanpy's `save` argument is a
                # filename suffix, so confirm which was intended.
                plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)

        if plot_style == "heatmap":
            for label in self.labels_to_plot:
                if label not in emb_labels:
                    logger.warning(
                        f"Label {label} from labels_to_plot "
                        f"not present in provided embeddings dataframe."
                    )
                    continue
                output_prefix_label = output_prefix + f"_heatmap_{label}"
                output_file = (
                    Path(output_directory) / output_prefix_label
                ).with_suffix(".pdf")
                plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
geneformer/gene_median_dictionary.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39
3
+ size 940965
geneformer/gene_name_id_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55e67962e79c0039a6c32d43c5c99f38e51964bbcfa32f736150ee1e285c438c
3
+ size 1117117
geneformer/in_silico_perturber.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber.
3
+
4
+ **Usage:**
5
+
6
+ .. code-block :: python
7
+
8
+ >>> from geneformer import InSilicoPerturber
9
+ >>> isp = InSilicoPerturber(perturb_type="delete",
10
+ ... perturb_rank_shift=None,
11
+ ... genes_to_perturb="all",
12
+ ... model_type="CellClassifier",
13
+ ... num_classes=0,
14
+ ... emb_mode="cell",
15
+ ... filter_data={"cell_type":["cardiomyocyte"]},
16
+ ... cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
17
+ ... state_embs_dict ={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
18
+ ... max_ncells=None,
19
+ ... emb_layer=0,
20
+ ... forward_batch_size=100,
21
+ ... nproc=16)
22
+ >>> isp.perturb_data("path/to/model",
23
+ ... "path/to/input_data",
24
+ ... "path/to/output_directory",
25
+ ... "output_prefix")
26
+
27
+ **Description:**
28
+
29
+ | Performs in silico perturbation (e.g. deletion or overexpression) of defined set of genes or all genes in sample of cells.
30
+ | Outputs impact of perturbation on cell or gene embeddings.
31
+ | Output files are analyzed with ``in_silico_perturber_stats``.
32
+
33
+ """
34
+
35
+ import logging
36
+
37
+ # imports
38
+ import os
39
+ import pickle
40
+ from collections import defaultdict
41
+ from typing import List
42
+
43
+ import seaborn as sns
44
+ import torch
45
+ from datasets import Dataset
46
+ from tqdm.auto import trange
47
+
48
+ from . import perturber_utils as pu
49
+ from .emb_extractor import get_embs
50
+ from .tokenizer import TOKEN_DICTIONARY_FILE
51
+
52
+ sns.set()
53
+
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
class InSilicoPerturber:
    """Perform in silico perturbation (deletion/overexpression) of genes in
    tokenized single-cell data and quantify the resulting shift in cell
    and/or gene embeddings."""

    # Registry of allowed values (or allowed types) per constructor argument;
    # consumed by ``validate_options`` to reject invalid configurations.
    valid_option_dict = {
        "perturb_type": {"delete", "overexpress", "inhibit", "activate"},
        "perturb_rank_shift": {None, 1, 2, 3},
        "genes_to_perturb": {"all", list},
        "combos": {0, 1},
        "anchor_gene": {None, str},
        "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
        "num_classes": {int},
        "emb_mode": {"cell", "cell_and_gene"},
        "cell_emb_style": {"mean_pool"},
        "filter_data": {None, dict},
        "cell_states_to_model": {None, dict},
        "state_embs_dict": {None, dict},
        "max_ncells": {None, int},
        "cell_inds_to_perturb": {"all", dict},
        "emb_layer": {-1, 0},
        "forward_batch_size": {int},
        "nproc": {int},
    }
78
+
79
def __init__(
    self,
    perturb_type="delete",
    perturb_rank_shift=None,
    genes_to_perturb="all",
    combos=0,
    anchor_gene=None,
    model_type="Pretrained",
    num_classes=0,
    emb_mode="cell",
    cell_emb_style="mean_pool",
    filter_data=None,
    cell_states_to_model=None,
    state_embs_dict=None,
    max_ncells=None,
    cell_inds_to_perturb="all",
    emb_layer=-1,
    forward_batch_size=100,
    nproc=4,
    token_dictionary_file=TOKEN_DICTIONARY_FILE,
):
    """
    Initialize in silico perturber.

    **Parameters:**

    perturb_type : {"delete", "overexpress", "inhibit", "activate"}
        | Type of perturbation.
        | "delete": delete gene from rank value encoding
        | "overexpress": move gene to front of rank value encoding
        | *(TBA)* "inhibit": move gene to lower quartile of rank value encoding
        | *(TBA)* "activate": move gene to higher quartile of rank value encoding
    *(TBA)* perturb_rank_shift : None, {1,2,3}
        | Number of quartiles by which to shift rank of gene.
        | For example, if perturb_type="activate" and perturb_rank_shift=1:
        | genes in 4th quartile will move to middle of 3rd quartile.
        | genes in 3rd quartile will move to middle of 2nd quartile.
        | genes in 2nd quartile will move to middle of 1st quartile.
        | genes in 1st quartile will move to front of rank value encoding.
        | For example, if perturb_type="inhibit" and perturb_rank_shift=2:
        | genes in 1st quartile will move to middle of 3rd quartile.
        | genes in 2nd quartile will move to middle of 4th quartile.
        | genes in 3rd or 4th quartile will move to bottom of rank value encoding.
    genes_to_perturb : "all", list
        | Default is perturbing each gene detected in each cell in the dataset.
        | Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
        | If gene list is provided, then perturber will only test perturbing them all together
        | (rather than testing each possible combination of the provided genes).
    combos : {0,1}
        | Whether to perturb genes individually (0) or in pairs (1).
    anchor_gene : None, str
        | ENSEMBL ID of gene to use as anchor in combination perturbations.
        | For example, if combos=1 and anchor_gene="ENSG00000148400":
        | anchor gene will be perturbed in combination with each other gene.
    model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
        | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
    num_classes : int
        | If model is a gene or cell classifier, specify number of classes it was trained to classify.
        | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
    emb_mode : {"cell", "cell_and_gene"}
        | Whether to output impact of perturbation on cell and/or gene embeddings.
        | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
    cell_emb_style : "mean_pool"
        | Method for summarizing cell embeddings.
        | Currently only option is mean pooling of gene embeddings for given cell.
    filter_data : None, dict
        | Default is to use all input data for in silico perturbation study.
        | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
    cell_states_to_model : None, dict
        | Cell states to model if testing perturbations that achieve goal state change.
        | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
        | state_key: key specifying name of column in .dataset that defines the start/goal states
        | start_state: value in the state_key column that specifies the start state
        | goal_state: value in the state_key column that specifies the goal end state
        | alt_states: list of values in the state_key column that specify the alternate end states
        | For example: {"state_key": "disease",
        |               "start_state": "dcm",
        |               "goal_state": "nf",
        |               "alt_states": ["hcm", "other1", "other2"]}
    state_embs_dict : None, dict
        | Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
        | Dictionary with keys specifying each possible cell state to model.
        | Values are target embedding positions as torch.tensor.
        | For example: {"nf": emb_nf,
        |               "hcm": emb_hcm,
        |               "dcm": emb_dcm,
        |               "other1": emb_other1,
        |               "other2": emb_other2}
    max_ncells : None, int
        | Maximum number of cells to test.
        | If None, will test all cells.
    cell_inds_to_perturb : "all", dict
        | Default is perturbing each cell in the dataset.
        | Otherwise, may provide a dict of indices of cells to perturb with keys start and end.
        | start: the first index to perturb.
        | end: the last index to perturb (exclusive).
        | Indices will be selected *after* the filter_data criteria and sorting.
        | Useful for splitting extremely large datasets across separate GPUs.
    emb_layer : {-1, 0}
        | Embedding layer to use for quantification.
        | 0: last layer (recommended for questions closely tied to model's training objective)
        | -1: 2nd to last layer (recommended for questions requiring more general representations)
    forward_batch_size : int
        | Batch size for forward pass.
    nproc : int
        | Number of CPU processes to use.
    token_dictionary_file : Path
        | Path to pickle file containing token dictionary (Ensembl ID:token).
    """

    self.perturb_type = perturb_type
    self.perturb_rank_shift = perturb_rank_shift
    self.genes_to_perturb = genes_to_perturb
    self.combos = combos
    self.anchor_gene = anchor_gene
    if self.genes_to_perturb == "all":
        self.perturb_group = False
    else:
        # a specific gene list implies group perturbation: all listed genes
        # are perturbed together, so anchor/combination modes do not apply
        self.perturb_group = True
        if (self.anchor_gene is not None) or (self.combos != 0):
            self.anchor_gene = None
            self.combos = 0
            logger.warning(
                "anchor_gene set to None and combos set to 0. "
                "If providing list of genes to perturb, "
                "list of genes_to_perturb will be perturbed together, "
                "without anchor gene or combinations."
            )
    self.model_type = model_type
    self.num_classes = num_classes
    self.emb_mode = emb_mode
    self.cell_emb_style = cell_emb_style
    self.filter_data = filter_data
    self.cell_states_to_model = cell_states_to_model
    self.state_embs_dict = state_embs_dict
    self.max_ncells = max_ncells
    self.cell_inds_to_perturb = cell_inds_to_perturb
    self.emb_layer = emb_layer
    self.forward_batch_size = forward_batch_size
    self.nproc = nproc

    self.validate_options()

    # load token dictionary (Ensembl IDs:token)
    with open(token_dictionary_file, "rb") as f:
        self.gene_token_dict = pickle.load(f)

    self.pad_token_id = self.gene_token_dict.get("<pad>")

    if self.anchor_gene is None:
        self.anchor_token = None
    else:
        try:
            self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
        except KeyError:
            logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
            raise

    if self.genes_to_perturb == "all":
        self.tokens_to_perturb = "all"
    else:
        missing_genes = [
            gene
            for gene in self.genes_to_perturb
            if gene not in self.gene_token_dict.keys()
        ]
        if len(missing_genes) == len(self.genes_to_perturb):
            # previously a bare ``raise`` (no active exception), which would
            # itself error with an opaque RuntimeError
            msg = "None of the provided genes to perturb are in token dictionary."
            logger.error(msg)
            raise ValueError(msg)
        elif len(missing_genes) > 0:
            logger.warning(
                f"Genes to perturb {missing_genes} are not in token dictionary."
            )
        # skip missing genes entirely: ``dict.get`` previously inserted None
        # tokens, which would corrupt downstream length/overflow arithmetic
        self.tokens_to_perturb = [
            self.gene_token_dict[gene]
            for gene in self.genes_to_perturb
            if gene in self.gene_token_dict
        ]
257
+
258
def validate_options(self):
    """
    Validate the constructor arguments against ``valid_option_dict`` and
    enforce cross-argument compatibility, normalizing incompatible settings
    where a safe default exists (with a warning) and raising ValueError
    otherwise.
    """
    # first disallow options under development
    if self.perturb_type in ["inhibit", "activate"]:
        msg = (
            "In silico inhibition and activation currently under development. "
            "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
        )
        logger.error(msg)
        raise ValueError(msg)
    # NOTE: check anchor_gene (set before this method runs), not anchor_token
    # (set after it) — the original self.anchor_token check raised
    # AttributeError whenever combos > 0
    if (self.combos > 0) and (self.anchor_gene is None):
        msg = (
            "Combination perturbation without anchor gene is currently under development. "
            "Currently, must provide anchor gene for combination perturbation."
        )
        logger.error(msg)
        raise ValueError(msg)

    # confirm arguments are within valid options and compatible with each other
    for attr_name, valid_options in self.valid_option_dict.items():
        attr_value = self.__dict__[attr_name]
        if type(attr_value) not in {list, dict}:
            if attr_value in valid_options:
                continue
        if attr_name in ["anchor_gene"]:
            # bug fix: check the *value's* type (original checked
            # type(attr_name), which is always str, so anchor_gene was
            # never actually validated)
            if type(attr_value) in {str}:
                continue
        valid_type = False
        for option in valid_options:
            if (option in [bool, int, list, dict]) and isinstance(
                attr_value, option
            ):
                valid_type = True
                break
        if valid_type:
            continue
        msg = (
            f"Invalid option for {attr_name}. "
            f"Valid options for {attr_name}: {valid_options}"
        )
        logger.error(msg)
        raise ValueError(msg)

    if self.perturb_type in ["delete", "overexpress"]:
        if self.perturb_rank_shift is not None:
            # quartile shifting does not apply to delete/overexpress
            if self.perturb_type == "delete":
                logger.warning(
                    "perturb_rank_shift set to None. "
                    "If perturb type is delete then gene is deleted entirely "
                    "rather than shifted by quartile"
                )
            elif self.perturb_type == "overexpress":
                logger.warning(
                    "perturb_rank_shift set to None. "
                    "If perturb type is overexpress then gene is moved to front "
                    "of rank value encoding rather than shifted by quartile"
                )
        self.perturb_rank_shift = None

    if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
        self.emb_mode = "cell"
        logger.warning(
            "emb_mode set to 'cell'. "
            "Currently, analysis with anchor gene "
            "only outputs effect on cell embeddings."
        )

    if self.cell_states_to_model is not None:
        pu.validate_cell_states_to_model(self.cell_states_to_model)

        if self.anchor_gene is not None:
            self.anchor_gene = None
            logger.warning(
                "anchor_gene set to None. "
                "Currently, anchor gene not available "
                "when modeling multiple cell states."
            )

        if self.state_embs_dict is None:
            msg = (
                "state_embs_dict must be provided for mode with cell_states_to_model. "
                "Format is dictionary with keys specifying each possible cell state to model. "
                "Values are target embedding positions as torch.tensor."
            )
            logger.error(msg)
            raise ValueError(msg)

        for state_emb in self.state_embs_dict.values():
            if not torch.is_tensor(state_emb):
                msg = "state_embs_dict must be dictionary with values being torch.tensor."
                logger.error(msg)
                raise ValueError(msg)

        # every modeled state must have a target embedding position
        keys_absent = []
        for k, v in self.cell_states_to_model.items():
            if (k == "start_state") or (k == "goal_state"):
                if v not in self.state_embs_dict.keys():
                    keys_absent.append(v)
            if k == "alt_states":
                for state in v:
                    if state not in self.state_embs_dict.keys():
                        keys_absent.append(state)
        if len(keys_absent) > 0:
            msg = (
                "Each start_state, goal_state, and alt_states in cell_states_to_model "
                "must be a key in state_embs_dict with the value being "
                "the state's embedding position as torch.tensor. "
                f"Missing keys: {keys_absent}"
            )
            logger.error(msg)
            raise ValueError(msg)

    if self.perturb_type in ["inhibit", "activate"]:
        if self.perturb_rank_shift is None:
            msg = (
                "If perturb_type is inhibit or activate then "
                "quartile to shift by must be specified."
            )
            logger.error(msg)
            raise ValueError(msg)

    if self.filter_data is not None:
        for key, value in self.filter_data.items():
            if not isinstance(value, list):
                # normalize scalar filter values to single-element lists
                self.filter_data[key] = [value]
                logger.warning(
                    "Values in filter_data dict must be lists. "
                    f"Changing {key} value to list ([{value}])."
                )

    if self.cell_inds_to_perturb != "all":
        if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
            msg = "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
            logger.error(msg)
            raise ValueError(msg)
        if (
            self.cell_inds_to_perturb["start"] < 0
            or self.cell_inds_to_perturb["end"] < 0
        ):
            msg = "cell_inds_to_perturb must be positive."
            logger.error(msg)
            raise ValueError(msg)
394
+
395
def perturb_data(
    self, model_directory, input_data_file, output_directory, output_prefix
):
    """
    Perturb genes in input data and save results in output_directory.

    **Parameters:**

    model_directory : Path
        | Path to directory containing model
    input_data_file : Path
        | Path to directory containing .dataset inputs
    output_directory : Path
        | Path to directory where perturbation data will be saved as batched pickle files
    output_prefix : str
        | Prefix for output files
    """

    # prefix all output files with the perturbation type for traceability
    output_path_prefix = os.path.join(
        output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
    )

    # load the model and derive forward-pass parameters from it
    model = pu.load_model(self.model_type, self.num_classes, model_directory)
    self.max_len = pu.get_model_input_size(model)
    layer_to_quant = pu.quant_layers(model) + self.emb_layer

    # general filtering based on the filter_data argument,
    # followed by mode-dependent filtering
    filtered_input_data = pu.load_and_filter(
        self.filter_data, self.nproc, input_data_file
    )
    filtered_input_data = self.apply_additional_filters(filtered_input_data)

    # dispatch: grouped perturbation of a provided gene list,
    # or individual perturbation of every detected gene
    if self.perturb_group:
        self.isp_perturb_set(
            model, filtered_input_data, layer_to_quant, output_path_prefix
        )
    else:
        self.isp_perturb_all(
            model, filtered_input_data, layer_to_quant, output_path_prefix
        )
438
+
439
def apply_additional_filters(self, filtered_input_data):
    """
    Apply perturbation-mode-dependent filters to the input dataset and
    return the filtered (downsampled, sorted, optionally sliced) dataset.
    """
    if self.cell_states_to_model is not None:
        # keep only cells in the modeled start state (logs the result)
        filtered_input_data = pu.filter_data_by_start_state(
            filtered_input_data, self.cell_states_to_model, self.nproc
        )

    if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
        # deletion only makes sense for cells that contain the target tokens
        filtered_input_data = pu.filter_data_by_tokens_and_log(
            filtered_input_data,
            self.tokens_to_perturb,
            self.nproc,
            "genes_to_perturb",
        )

    if self.anchor_token is not None:
        # combination mode requires the anchor gene to be present
        filtered_input_data = pu.filter_data_by_tokens_and_log(
            filtered_input_data, self.anchor_token, self.nproc, "anchor_gene"
        )

    # downsample and sort largest to smallest to encounter memory constraints earlier
    filtered_input_data = pu.downsample_and_sort(
        filtered_input_data, self.max_ncells
    )

    # slice dataset if cell_inds_to_perturb is not "all" (multi-GPU splits)
    if self.cell_inds_to_perturb != "all":
        filtered_input_data = pu.slice_by_inds_to_perturb(
            filtered_input_data, self.cell_inds_to_perturb
        )

    return filtered_input_data
474
+
475
def isp_perturb_set(
    self,
    model,
    filtered_input_data: "Dataset",
    layer_to_quant: int,
    output_path_prefix: str,
):
    """
    Perturb the provided genes_to_perturb set together in every cell and
    write the resulting cosine-shift dictionaries to disk.

    model: loaded model used for the forward passes
    filtered_input_data: filtered .dataset of tokenized cells
    layer_to_quant: index of the hidden layer to quantify embeddings from
    output_path_prefix: path prefix for the output pickle files
    """

    def make_group_perturbation_batch(example):
        # locate each target token in this cell's rank value encoding,
        # record the indices, and apply the perturbation
        encoding = example["input_ids"]
        example["tokens_to_perturb"] = self.tokens_to_perturb
        found_indices = [
            encoding.index(token) if token in encoding else None
            for token in self.tokens_to_perturb
        ]
        found_indices = [idx for idx in found_indices if idx is not None]
        if found_indices:
            example["perturb_index"] = found_indices
        else:
            # -100 indicates tokens to overexpress are not present in rank value encoding
            example["perturb_index"] = [-100]
        if self.perturb_type == "delete":
            example = pu.delete_indices(example)
        elif self.perturb_type == "overexpress":
            example = pu.overexpress_tokens(example, self.max_len)
            example["n_overflow"] = pu.calc_n_overflow(
                self.max_len,
                example["length"],
                self.tokens_to_perturb,
                found_indices,
            )
        return example

    total_batch_length = len(filtered_input_data)
    if self.cell_states_to_model is None:
        cos_sims_dict = defaultdict(list)
    else:
        # one result dictionary per modeled state
        cos_sims_dict = {
            state: defaultdict(list)
            for state in pu.get_possible_states(self.cell_states_to_model)
        }

    perturbed_data = filtered_input_data.map(
        make_group_perturbation_batch, num_proc=self.nproc
    )
    if self.perturb_type == "overexpress":
        filtered_input_data = filtered_input_data.add_column(
            "n_overflow", perturbed_data["n_overflow"]
        )
        # Remove overflow genes from the original data so that embeddings stay
        # comparable: if the original cell holds genes 0:2047 and new gene 2048
        # is overexpressed, the perturbed cell becomes 2048+0:2046, so the
        # original is truncated to 0:2046 (otherwise we would be modeling the
        # effect of both deleting 2047 and adding 2048, rather than only
        # adding 2048).
        filtered_input_data = filtered_input_data.map(
            pu.truncate_by_n_overflow, num_proc=self.nproc
        )

    if self.emb_mode == "cell_and_gene":
        stored_gene_embs_dict = defaultdict(list)

    # iterate through minibatches of cells
    for batch_start in trange(0, total_batch_length, self.forward_batch_size):
        batch_end = min(batch_start + self.forward_batch_size, total_batch_length)
        batch_inds = list(range(batch_start, batch_end))

        minibatch = filtered_input_data.select(batch_inds)
        perturbation_batch = perturbed_data.select(batch_inds)

        if self.cell_emb_style == "mean_pool":
            full_original_emb = get_embs(
                model,
                minibatch,
                "gene",
                layer_to_quant,
                self.pad_token_id,
                self.forward_batch_size,
                summary_stat=None,
                silent=True,
            )
            indices_to_perturb = perturbation_batch["perturb_index"]
            # drop the perturbed positions so original/perturbed align gene-wise
            original_emb = pu.remove_perturbed_indices_set(
                full_original_emb,
                self.perturb_type,
                indices_to_perturb,
                self.tokens_to_perturb,
                minibatch["length"],
            )
            full_perturbation_emb = get_embs(
                model,
                perturbation_batch,
                "gene",
                layer_to_quant,
                self.pad_token_id,
                self.forward_batch_size,
                summary_stat=None,
                silent=True,
            )

            # overexpressed genes sit at the front of the encoding: strip them
            if self.perturb_type == "overexpress":
                perturbation_emb = full_perturbation_emb[
                    :, len(self.tokens_to_perturb) :, :
                ]
            elif self.perturb_type == "delete":
                perturbation_emb = full_perturbation_emb[
                    :, : max(perturbation_batch["length"]), :
                ]

            n_perturbation_genes = perturbation_emb.size()[1]

            # without goal states, the cell shift is the mean of the
            # per-gene cosine similarities
            if (
                self.cell_states_to_model is None
                or self.emb_mode == "cell_and_gene"
            ):
                gene_cos_sims = pu.quant_cos_sims(
                    perturbation_emb,
                    original_emb,
                    self.cell_states_to_model,
                    self.state_embs_dict,
                    emb_mode="gene",
                )

            # with goal states, quantify the shift of the pooled cell embedding
            if self.cell_states_to_model is not None:
                original_cell_emb = pu.mean_nonpadding_embs(
                    full_original_emb,
                    torch.tensor(minibatch["length"], device="cuda"),
                    dim=1,
                )
                perturbation_cell_emb = pu.mean_nonpadding_embs(
                    full_perturbation_emb,
                    torch.tensor(perturbation_batch["length"], device="cuda"),
                    dim=1,
                )
                cell_cos_sims = pu.quant_cos_sims(
                    perturbation_cell_emb,
                    original_cell_emb,
                    self.cell_states_to_model,
                    self.state_embs_dict,
                    emb_mode="cell",
                )

            # gene-level output needs gene names for the remaining positions
            if self.emb_mode == "cell_and_gene":
                gene_list = minibatch["input_ids"]
                # truncate to the positions actually kept in perturbation_emb
                gene_list = [
                    [g for g in genes if g not in self.tokens_to_perturb][
                        :n_perturbation_genes
                    ]
                    for genes in gene_list
                ]

                for cell_i, genes in enumerate(gene_list):
                    for gene_j, affected_gene in enumerate(genes):
                        if len(self.genes_to_perturb) > 1:
                            tokens_to_perturb = tuple(self.tokens_to_perturb)
                        else:
                            tokens_to_perturb = self.tokens_to_perturb[0]

                        # fill in the gene cosine similarities
                        try:
                            stored_gene_embs_dict[
                                (tokens_to_perturb, affected_gene)
                            ].append(gene_cos_sims[cell_i, gene_j].item())
                        except KeyError:
                            stored_gene_embs_dict[
                                (tokens_to_perturb, affected_gene)
                            ] = gene_cos_sims[cell_i, gene_j].item()
            else:
                gene_list = None

            if self.cell_states_to_model is None:
                # mean of gene cosine similarities over non-padding positions
                if self.perturb_type == "overexpress":
                    # overexpressed genes were removed before the cos sims,
                    # so shorten the nonpadding lengths accordingly
                    n_overexpressed = len(self.tokens_to_perturb)
                    nonpadding_lens = [
                        x - n_overexpressed for x in perturbation_batch["length"]
                    ]
                else:
                    nonpadding_lens = perturbation_batch["length"]
                cos_sims_data = pu.mean_nonpadding_embs(
                    gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
                )
                cos_sims_dict = self.update_perturbation_dictionary(
                    cos_sims_dict,
                    cos_sims_data,
                    filtered_input_data,
                    indices_to_perturb,
                    gene_list,
                )
            else:
                cos_sims_data = cell_cos_sims
                for state in cos_sims_dict.keys():
                    cos_sims_dict[state] = self.update_perturbation_dictionary(
                        cos_sims_dict[state],
                        cos_sims_data[state],
                        filtered_input_data,
                        indices_to_perturb,
                        gene_list,
                    )

            # release per-batch tensors before the next iteration
            del minibatch
            del perturbation_batch
            del original_emb
            del perturbation_emb
            del cos_sims_data

            torch.cuda.empty_cache()

    pu.write_perturbation_dictionary(
        cos_sims_dict,
        f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
    )

    if self.emb_mode == "cell_and_gene":
        pu.write_perturbation_dictionary(
            stored_gene_embs_dict,
            f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
        )
703
+
704
def isp_perturb_all(
    self,
    model,
    filtered_input_data: "Dataset",
    layer_to_quant: int,
    output_path_prefix: str,
):
    """
    Perturb every detected gene individually (optionally in combination with
    the anchor gene) in each cell, checkpointing cosine-shift dictionaries to
    disk in batches of 1000 cells.

    model: loaded model used for the forward passes
    filtered_input_data: filtered .dataset of tokenized cells
    layer_to_quant: index of the hidden layer to quantify embeddings from
    output_path_prefix: path prefix for the output pickle files
    """
    pickle_batch = -1
    if self.cell_states_to_model is None:
        cos_sims_dict = defaultdict(list)
    else:
        # one result dictionary per modeled state
        cos_sims_dict = {
            state: defaultdict(list)
            for state in pu.get_possible_states(self.cell_states_to_model)
        }

    if self.emb_mode == "cell_and_gene":
        stored_gene_embs_dict = defaultdict(list)

    for i in trange(len(filtered_input_data)):
        example_cell = filtered_input_data.select([i])
        full_original_emb = get_embs(
            model,
            example_cell,
            "gene",
            layer_to_quant,
            self.pad_token_id,
            self.forward_batch_size,
            summary_stat=None,
            silent=True,
        )

        # gene_list maps cosine shifts back to genes; the anchor gene is
        # removed since it is perturbed in every combination, not measured
        gene_list = example_cell["input_ids"][0][:]
        if self.anchor_token is not None:
            for token in self.anchor_token:
                gene_list.remove(token)

        perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
            example_cell,
            self.perturb_type,
            self.tokens_to_perturb,
            self.anchor_token,
            self.combos,
            self.nproc,
        )

        full_perturbation_emb = get_embs(
            model,
            perturbation_batch,
            "gene",
            layer_to_quant,
            self.pad_token_id,
            self.forward_batch_size,
            summary_stat=None,
            silent=True,
        )

        num_inds_perturbed = 1 + self.combos
        # overexpressed gene(s) must be removed to quantify cosine shifts
        if self.perturb_type == "overexpress":
            perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
            gene_list = gene_list[
                num_inds_perturbed:
            ]  # index 0 is not overexpressed

        elif self.perturb_type == "delete":
            perturbation_emb = full_perturbation_emb

        original_batch = pu.make_comparison_batch(
            full_original_emb, indices_to_perturb, perturb_group=False
        )

        if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
            gene_cos_sims = pu.quant_cos_sims(
                perturbation_emb,
                original_batch,
                self.cell_states_to_model,
                self.state_embs_dict,
                emb_mode="gene",
            )
        if self.cell_states_to_model is not None:
            original_cell_emb = pu.compute_nonpadded_cell_embedding(
                full_original_emb, "mean_pool"
            )
            perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
                full_perturbation_emb, "mean_pool"
            )

            cell_cos_sims = pu.quant_cos_sims(
                perturbation_cell_emb,
                original_cell_emb,
                self.cell_states_to_model,
                self.state_embs_dict,
                emb_mode="cell",
            )

        if self.emb_mode == "cell_and_gene":
            # for each perturbation, the affected-gene list excludes the
            # perturbed gene itself (renamed comprehension index to avoid
            # shadowing the outer loop variable)
            perturbed_gene_dict = {
                gene: gene_list[:idx] + gene_list[idx + 1 :]
                for idx, gene in enumerate(gene_list)
            }

            for perturbation_i, perturbed_gene in enumerate(gene_list):
                for gene_j, affected_gene in enumerate(
                    perturbed_gene_dict[perturbed_gene]
                ):
                    try:
                        stored_gene_embs_dict[
                            (perturbed_gene, affected_gene)
                        ].append(gene_cos_sims[perturbation_i, gene_j].item())
                    except KeyError:
                        stored_gene_embs_dict[
                            (perturbed_gene, affected_gene)
                        ] = gene_cos_sims[perturbation_i, gene_j].item()

        if self.cell_states_to_model is None:
            cos_sims_data = torch.mean(gene_cos_sims, dim=1)
            cos_sims_dict = self.update_perturbation_dictionary(
                cos_sims_dict,
                cos_sims_data,
                filtered_input_data,
                indices_to_perturb,
                gene_list,
            )
        else:
            cos_sims_data = cell_cos_sims
            for state in cos_sims_dict.keys():
                cos_sims_dict[state] = self.update_perturbation_dictionary(
                    cos_sims_dict[state],
                    cos_sims_data[state],
                    filtered_input_data,
                    indices_to_perturb,
                    gene_list,
                )

        # checkpoint the dictionaries to disk every 100 cells
        if i % 100 == 0:
            pu.write_perturbation_dictionary(
                cos_sims_dict,
                f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
            )
            if self.emb_mode == "cell_and_gene":
                pu.write_perturbation_dictionary(
                    stored_gene_embs_dict,
                    f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
                )

        # start a fresh pickle batch and clear memory every 1000 cells
        if i % 1000 == 0:
            pickle_batch += 1
            if self.cell_states_to_model is None:
                cos_sims_dict = defaultdict(list)
            else:
                cos_sims_dict = {
                    state: defaultdict(list)
                    for state in pu.get_possible_states(self.cell_states_to_model)
                }

            if self.emb_mode == "cell_and_gene":
                stored_gene_embs_dict = defaultdict(list)

        torch.cuda.empty_cache()

    pu.write_perturbation_dictionary(
        cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
    )

    if self.emb_mode == "cell_and_gene":
        pu.write_perturbation_dictionary(
            stored_gene_embs_dict,
            f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
        )
878
+
879
def update_perturbation_dictionary(
    self,
    cos_sims_dict: "defaultdict",
    cos_sims_data: "torch.Tensor",
    filtered_input_data: "Dataset",
    indices_to_perturb: "List[List[int]]",
    gene_list=None,
):
    """
    Append one batch of cosine-similarity results to cos_sims_dict and
    return the updated dictionary.

    In group mode (perturb_group), all results are keyed by the single
    tuple/token of perturbed genes; otherwise each result is keyed by the
    individual perturbed gene from gene_list. Dictionary keys are
    (perturbed gene(s), "cell_emb").

    Raises ValueError if gene_list length does not match the number of
    cosine-similarity rows.
    """
    # sanity check: exactly one cosine similarity per perturbed gene
    if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
        # previously logged a malformed message ("len(cos_sims_data.shape[0])")
        # and used a bare ``raise`` with no active exception
        msg = (
            f"cos_sims_data.shape[0] ({cos_sims_data.shape[0]}) does not match "
            f"len(gene_list) ({len(gene_list)})."
        )
        logger.error(msg)
        raise ValueError(msg)

    if self.perturb_group is True:
        # all genes perturbed together -> one dictionary key for the group
        if len(self.tokens_to_perturb) > 1:
            perturbed_genes = tuple(self.tokens_to_perturb)
        else:
            perturbed_genes = self.tokens_to_perturb[0]

        # cell embeddings arrive with shape (batch size, 1); flatten to a list
        cos_sims_data = torch.squeeze(cos_sims_data).tolist()

        # squeezing a single remaining cell yields a scalar; rewrap as a list
        if not isinstance(cos_sims_data, list):
            cos_sims_data = [cos_sims_data]

        cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data

    else:
        # one perturbation per gene -> key each result by its perturbed gene
        for i, cos in enumerate(cos_sims_data.tolist()):
            cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)

    return cos_sims_dict
geneformer/in_silico_perturber_stats.py ADDED
@@ -0,0 +1,1042 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber stats generator.
3
+
4
+ **Usage:**
5
+
6
+ .. code-block :: python
7
+
8
+ >>> from geneformer import InSilicoPerturberStats
9
+ >>> ispstats = InSilicoPerturberStats(mode="goal_state_shift",
10
+ ... cell_states_to_model={"state_key": "disease",
11
+ ... "start_state": "dcm",
12
+ ... "goal_state": "nf",
13
+ ... "alt_states": ["hcm", "other1", "other2"]})
14
+ >>> ispstats.get_stats("path/to/input_data",
15
+ ... None,
16
+ ... "path/to/output_directory",
17
+ ... "output_prefix")
18
+
19
+ **Description:**
20
+
21
+ | Aggregates data or calculates stats for in silico perturbations based on type of statistics specified in InSilicoPerturberStats.
22
+ | Input data is raw in silico perturbation results in the form of dictionaries outputted by ``in_silico_perturber``.
23
+
24
+ """
25
+
26
+
27
+ import logging
28
+ import os
29
+ import pickle
30
+ import random
31
+ from pathlib import Path
32
+
33
+ import numpy as np
34
+ import pandas as pd
35
+ import statsmodels.stats.multitest as smt
36
+ from scipy.stats import ranksums
37
+ from sklearn.mixture import GaussianMixture
38
+ from tqdm.auto import tqdm, trange
39
+
40
+ from .perturber_utils import flatten_list, validate_cell_states_to_model
41
+ from .tokenizer import TOKEN_DICTIONARY_FILE
42
+
43
+ GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
# invert dictionary keys/values
def invert_dict(dictionary):
    """Return a new dict with keys and values of *dictionary* swapped.

    When several keys share a value, the last one seen wins.
    """
    inverted = {}
    for key, value in dictionary.items():
        inverted[value] = key
    return inverted
51
+
52
+
53
def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
    """
    Filter one raw cosine-similarity dictionary down to the entries of interest.

    "cell" mode keeps non-empty entries whose key tuple contains "cell_emb";
    "gene" mode keeps all non-empty entries, optionally restricted to those
    whose first key element equals *anchor_token*. The result is wrapped in a
    single-element list. Any other mode returns None (unchanged behavior).
    """
    if cell_or_gene_emb == "cell":
        filtered = {
            key: value
            for key, value in cos_sims_dict.items()
            if value and "cell_emb" in key
        }
        return [filtered]
    if cell_or_gene_emb == "gene":
        filtered = {}
        for key, value in cos_sims_dict.items():
            if not value:
                continue
            if anchor_token is not None and key[0] != anchor_token:
                continue
            filtered[key] = value
        return [filtered]
67
+
68
+
69
# read raw dictionary files
def read_dictionaries(
    input_data_directory,
    cell_or_gene_emb,
    anchor_token,
    cell_states_to_model,
    pickle_suffix,
):
    """
    Load and merge all raw in silico perturbation pickles from a directory.

    **Parameters:**

    input_data_directory : Path
        | Directory containing pickled result dictionaries.
    cell_or_gene_emb : {"cell", "gene"}
        | Which embedding entries to keep (see read_dict).
    anchor_token : int, None
        | Optional anchor gene token filter for "gene" mode.
    cell_states_to_model : dict, None
        | If provided, raw pickles are keyed per state value and merged into
        | one dict per state; if None, each pickle yields one dict in a list.
    pickle_suffix : str
        | Only files ending with this suffix (e.g. "_raw.pickle") are read.

    **Returns:**

    list of dicts (cell_states_to_model is None) or dict of state -> dict.

    **Raises:**

    FileNotFoundError
        | If no file with the given suffix exists in the directory.
    """
    file_found = False
    file_path_list = []
    if cell_states_to_model is None:
        dict_list = []
    else:
        validate_cell_states_to_model(cell_states_to_model)
        # keep only real state entries ("state_key" is metadata, not a state)
        cell_states_to_model_valid = {
            state: value
            for state, value in cell_states_to_model.items()
            if state != "state_key"
            and cell_states_to_model[state] is not None
            and cell_states_to_model[state] != []
        }
        cell_states_list = []
        # flatten all state values into list
        for state in cell_states_to_model_valid:
            value = cell_states_to_model_valid[state]
            if isinstance(value, list):
                cell_states_list += value
            else:
                cell_states_list.append(value)
        state_dict = {state_value: dict() for state_value in cell_states_list}
    for file in os.listdir(input_data_directory):
        # process only files with given suffix (e.g. "_raw.pickle")
        if file.endswith(pickle_suffix):
            file_found = True
            file_path_list += [f"{input_data_directory}/{file}"]
    for file_path in tqdm(file_path_list):
        with open(file_path, "rb") as fp:
            cos_sims_dict = pickle.load(fp)
        if cell_states_to_model is None:
            dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
        else:
            # merge this pickle's per-state entries into the accumulators
            for state_value in cell_states_list:
                new_dict = read_dict(
                    cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
                )[0]
                for key in new_dict:
                    try:
                        state_dict[state_value][key] += new_dict[key]
                    except KeyError:
                        state_dict[state_value][key] = new_dict[key]
    if not file_found:
        # bug fix: message was not an f-string, so "{pickle_suffix}" printed
        # literally; also raise a real exception instead of a bare `raise`
        # (which would surface as "RuntimeError: No active exception")
        msg = (
            "No raw data for processing found within provided directory. "
            f"Please ensure data files end with '{pickle_suffix}'."
        )
        logger.error(msg)
        raise FileNotFoundError(msg)
    if cell_states_to_model is None:
        return dict_list
    else:
        return state_dict
129
+
130
+
131
# get complete gene list
def get_gene_list(dict_list, mode):
    """
    Collect the sorted list of gene tokens present in raw results.

    mode "cell": keys look like (gene, "cell_emb") -> gene at position 0.
    mode "gene": keys look like (anchor, gene) -> gene at position 1; the
    aggregate "cell_emb" pseudo-entry is excluded.

    **Raises:**

    ValueError for an unknown mode; TypeError if dict_list is neither a
    list nor a dict.
    """
    if mode == "cell":
        position = 0
    elif mode == "gene":
        position = 1
    else:
        # bug fix: an unknown mode previously fell through and later raised
        # an opaque NameError on `position`
        raise ValueError(f"mode must be 'cell' or 'gene', not {mode!r}")
    gene_set = set()
    if isinstance(dict_list, list):
        for dict_i in dict_list:
            gene_set.update([k[position] for k, v in dict_i.items() if v])
    elif isinstance(dict_list, dict):
        for state, dict_i in dict_list.items():
            gene_set.update([k[position] for k, v in dict_i.items() if v])
    else:
        # bug fix: was a bare `raise` after logging (RuntimeError, message lost)
        msg = (
            "dict_list should be a list, or if modeling shift to goal states, a dict. "
            f"{type(dict_list)} is not the correct format."
        )
        logger.error(msg)
        raise TypeError(msg)
    gene_list = list(gene_set)
    if mode == "gene" and "cell_emb" in gene_set:
        # bug fix: unconditional remove() crashed when no cell_emb entry existed
        gene_list.remove("cell_emb")
    gene_list.sort()
    return gene_list
155
+
156
+
157
def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
    """Map a token tuple (or single token) to Ensembl ID(s).

    Unknown tokens map to NaN; a non-iterable input is looked up directly.
    """
    try:
        ids = [gene_token_id_dict.get(token, np.nan) for token in token_tuple]
    except TypeError:
        # non-iterable input: treat it as a single token
        return gene_token_id_dict.get(token_tuple, np.nan)
    return tuple(ids)
162
+
163
+
164
def n_detections(token, dict_list, mode, anchor_token):
    """Count total recorded cosine shifts for *token* across all result dicts.

    "cell" mode looks up (token, "cell_emb"); "gene" mode looks up
    (anchor_token, token). Any other mode counts nothing.
    """
    total = 0
    for result_dict in dict_list:
        if mode == "cell":
            matches = result_dict.get((token, "cell_emb"), [])
        elif mode == "gene":
            matches = result_dict.get((anchor_token, token), [])
        else:
            matches = []
        total += len(matches)
    return total
172
+
173
+
174
def get_fdr(pvalues):
    """Return Benjamini-Hochberg (FDR) adjusted p-values as a list."""
    corrected = smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1]
    return list(corrected)
176
+
177
+
178
def get_impact_component(test_value, gaussian_mixture_model):
    """
    Assign *test_value* to the impact (1) or non-impact (0) side of a fitted
    two-component mixture model.

    Values above the second component's mean map to 0 and values below the
    first component's mean map to 1 without consulting the model; values in
    between use the model's prediction with the label flipped (the fitted
    component order is inverted relative to the convention used here).
    """
    lower_mean = gaussian_mixture_model.means_[0][0]
    upper_mean = gaussian_mixture_model.means_[1][0]
    if test_value > upper_mean:
        return 0
    if test_value < lower_mean:
        return 1
    predicted = gaussian_mixture_model.predict([[test_value]])[0]
    return 0 if predicted == 1 else 1
192
+
193
+
194
# aggregate data for single perturbation in multiple cells
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
    """
    Gather every cosine shift recorded for the single perturbed gene (row 0 of
    *cos_sims_df*) across all result dictionaries into a one-column DataFrame.
    """
    perturbed_token = cos_sims_df["Gene"][0]
    shifts = []
    for result_dict in dict_list:
        shifts.extend(result_dict.get((perturbed_token, "cell_emb"), []))

    aggregated = pd.DataFrame(columns=["Cosine_shift"])
    aggregated["Cosine_shift"] = shifts
    return aggregated
205
+
206
+
207
def find(variable, x):
    """
    Return True if *x* is contained in *variable* (when iterable) or equal to
    it (when not).

    Bug fix: the original returned an implicit None (rather than False) when
    *variable* was iterable but did not contain *x*; this now always returns
    a bool, preserving truthiness for all existing callers.
    """
    try:
        return x in variable  # membership test for iterables
    except (ValueError, TypeError):
        return x == variable  # fall back to equality for non-iterables
213
+
214
+
215
def isp_aggregate_gene_shifts(
    cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
):
    """
    Aggregate per-(perturbed, affected) cosine shifts into a summary DataFrame.

    For each perturbed gene in cos_sims_df, collects every recorded shift whose
    key starts with that gene, then reports mean, stdev, and count per
    (perturbed, affected) pair. Rows where the affected entry is the aggregate
    "cell_emb" are sorted to the top, then by descending mean shift.
    """
    # accumulate raw shift lists keyed by the full (perturbed, affected) tuple
    cos_shift_data = dict()
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        for dict_i in dict_list:
            # find() matches token against k[0] whether k[0] is a single token
            # or a tuple of grouped perturbed tokens
            affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
            for key in affected_pairs:
                if key in cos_shift_data.keys():
                    cos_shift_data[key] += dict_i.get(key, [])
                else:
                    cos_shift_data[key] = dict_i.get(key, [])

    # summarize each pair as [mean, stdev, n]
    cos_data_mean = {
        k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
    }
    cos_sims_full_df = pd.DataFrame()
    cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
    # NOTE(review): the trailing [0] is a pandas *label* lookup, which assumes
    # the matching cos_sims_df row carries index label 0 (true for the
    # single-row grouped-perturbation case) — confirm for multi-row inputs
    cos_sims_full_df["Gene_name"] = [
        cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
        for k, v in cos_data_mean.items()
    ]
    cos_sims_full_df["Ensembl_ID"] = [
        cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
        for k, v in cos_data_mean.items()
    ]

    # map each affected token back to gene name / Ensembl ID (NaN if unknown,
    # including the "cell_emb" pseudo-token)
    cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
    cos_sims_full_df["Affected_gene_name"] = [
        gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
        for token in cos_sims_full_df["Affected"]
    ]
    cos_sims_full_df["Affected_Ensembl_ID"] = [
        gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
    ]
    cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
    cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
    cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]

    specific_val = "cell_emb"
    cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
    # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
    cos_sims_full_df = cos_sims_full_df.sort_values(
        by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
    ).drop("temp", axis=1)

    return cos_sims_full_df
263
+
264
+
265
# stats comparing cos sims shifts towards goal state of test perturbations vs random perturbations
def isp_stats_to_goal_state(
    cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
):
    """
    Summarize shifts toward the goal (and alternate) cell states.

    For a single grouped perturbation (genes_perturbed != "all") this simply
    reports mean shifts per state. For genes_perturbed == "all" it also
    computes rank-sum p-values of each gene's shifts against the pooled
    "random" background of all genes' shifts, plus BH-FDR corrections.
    """
    # determine whether any alternate end states were modeled; [None] and an
    # empty list both mean "no alternates"
    if (
        ("alt_states" not in cell_states_to_model.keys())
        or (len(cell_states_to_model["alt_states"]) == 0)
        or (cell_states_to_model["alt_states"] == [None])
    ):
        alt_end_state_exists = False
    elif (len(cell_states_to_model["alt_states"]) > 0) and (
        cell_states_to_model["alt_states"] != [None]
    ):
        alt_end_state_exists = True

    # for single perturbation in multiple cells, there are no random perturbations to compare to
    if genes_perturbed != "all":
        cos_sims_full_df = pd.DataFrame()

        cos_shift_data_end = []
        token = cos_sims_df["Gene"][0]
        cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
            (token, "cell_emb"), []
        )
        cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                cos_shift_data_alt_state = []
                cos_shift_data_alt_state += result_dict.get(alt_state).get(
                    (token, "cell_emb"), []
                )
                cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
                    np.mean(cos_shift_data_alt_state)
                ]

        # sort by shift to desired state
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Shift_to_goal_end"], ascending=[False]
        )
        return cos_sims_full_df

    elif genes_perturbed == "all":
        # pool every gene's shifts as the "random" background distribution
        goal_end_random_megalist = []
        if alt_end_state_exists is True:
            alt_end_state_random_dict = {
                alt_state: [] for alt_state in cell_states_to_model["alt_states"]
            }
        for i in trange(cos_sims_df.shape[0]):
            token = cos_sims_df["Gene"][i]
            goal_end_random_megalist += result_dict[
                cell_states_to_model["goal_state"]
            ].get((token, "cell_emb"), [])
            if alt_end_state_exists is True:
                for alt_state in cell_states_to_model["alt_states"]:
                    alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
                        (token, "cell_emb"), []
                    )

        # downsample to improve speed of ranksums
        if len(goal_end_random_megalist) > 100_000:
            random.seed(42)  # fixed seed for reproducible downsampling
            goal_end_random_megalist = random.sample(
                goal_end_random_megalist, k=100_000
            )
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                if len(alt_end_state_random_dict[alt_state]) > 100_000:
                    random.seed(42)
                    alt_end_state_random_dict[alt_state] = random.sample(
                        alt_end_state_random_dict[alt_state], k=100_000
                    )

        # assemble output columns; alt-state columns are inserted before the
        # p-value columns so related columns stay adjacent
        names = [
            "Gene",
            "Gene_name",
            "Ensembl_ID",
            "Shift_to_goal_end",
            "Goal_end_vs_random_pval",
        ]
        if alt_end_state_exists is True:
            [
                names.append(f"Shift_to_alt_end_{alt_state}")
                for alt_state in cell_states_to_model["alt_states"]
            ]
            names.append(names.pop(names.index("Goal_end_vs_random_pval")))
            [
                names.append(f"Alt_end_vs_random_pval_{alt_state}")
                for alt_state in cell_states_to_model["alt_states"]
            ]
        cos_sims_full_df = pd.DataFrame(columns=names)

        n_detections_dict = dict()
        for i in trange(cos_sims_df.shape[0]):
            token = cos_sims_df["Gene"][i]
            name = cos_sims_df["Gene_name"][i]
            ensembl_id = cos_sims_df["Ensembl_ID"][i]
            goal_end_cos_sim_megalist = result_dict[
                cell_states_to_model["goal_state"]
            ].get((token, "cell_emb"), [])
            n_detections_dict[token] = len(goal_end_cos_sim_megalist)
            mean_goal_end = np.mean(goal_end_cos_sim_megalist)
            # Wilcoxon rank-sum: this gene's shifts vs the pooled background
            pval_goal_end = ranksums(
                goal_end_random_megalist, goal_end_cos_sim_megalist
            ).pvalue

            if alt_end_state_exists is True:
                alt_end_state_dict = {
                    alt_state: [] for alt_state in cell_states_to_model["alt_states"]
                }
                for alt_state in cell_states_to_model["alt_states"]:
                    alt_end_state_dict[alt_state] = result_dict[alt_state].get(
                        (token, "cell_emb"), []
                    )
                    alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
                        alt_end_state_dict[alt_state]
                    )
                    alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
                        alt_end_state_random_dict[alt_state],
                        alt_end_state_dict[alt_state],
                    ).pvalue

            results_dict = dict()
            results_dict["Gene"] = token
            results_dict["Gene_name"] = name
            results_dict["Ensembl_ID"] = ensembl_id
            results_dict["Shift_to_goal_end"] = mean_goal_end
            results_dict["Goal_end_vs_random_pval"] = pval_goal_end
            if alt_end_state_exists is True:
                for alt_state in cell_states_to_model["alt_states"]:
                    results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
                        f"{alt_state}_mean"
                    ]
                    results_dict[
                        f"Alt_end_vs_random_pval_{alt_state}"
                    ] = alt_end_state_dict[f"{alt_state}_pval"]

            cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
            cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])

        # multiple-testing correction (Benjamini-Hochberg)
        cos_sims_full_df["Goal_end_FDR"] = get_fdr(
            list(cos_sims_full_df["Goal_end_vs_random_pval"])
        )
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
                    list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
                )

        # quantify number of detections of each gene
        cos_sims_full_df["N_Detections"] = [
            n_detections_dict[token] for token in cos_sims_full_df["Gene"]
        ]

        # sort by shift to desired state
        cos_sims_full_df["Sig"] = [
            1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
        ]
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
            ascending=[False, False, True],
        )

        return cos_sims_full_df
428
+
429
+
430
# stats comparing cos sim shifts of test perturbations vs null distribution
def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
    """
    Compare each gene's cosine shifts against a null distribution.

    For every gene in cos_sims_df, collects its shifts from dict_list and
    null_dict_list, reports the mean of each, their difference, a Wilcoxon
    rank-sum p-value (NaNs replaced with 1), BH-FDR correction, detection
    counts, and a 0/1 significance flag (FDR < 0.05), sorted by significance
    then effect size.
    """
    cos_sims_full_df = cos_sims_df.copy()

    # pre-allocate output columns with the correct dtypes
    cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
        cos_sims_df.shape[0], dtype=float
    )
    cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["N_Detections_test"] = np.zeros(
        cos_sims_df.shape[0], dtype="uint32"
    )
    cos_sims_full_df["N_Detections_null"] = np.zeros(
        cos_sims_df.shape[0], dtype="uint32"
    )

    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        test_shifts = []
        null_shifts = []

        for dict_i in dict_list:
            test_shifts += dict_i.get((token, "cell_emb"), [])

        for dict_i in null_dict_list:
            null_shifts += dict_i.get((token, "cell_emb"), [])

        cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
        cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
        cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
            test_shifts
        ) - np.mean(null_shifts)
        cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
            test_shifts, null_shifts, nan_policy="omit"
        ).pvalue
        cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
        cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)

    # replace NaN p-values (e.g. from empty shift lists) with 1.
    # perf fix: this whole-column rewrite previously ran inside the loop on
    # every iteration (O(n^2) total work); doing it once after the loop
    # produces the identical final column.
    cos_sims_full_df.Test_vs_null_pval = np.where(
        np.isnan(cos_sims_full_df.Test_vs_null_pval),
        1,
        cos_sims_full_df.Test_vs_null_pval,
    )

    cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
        cos_sims_full_df["Test_vs_null_pval"]
    )

    cos_sims_full_df["Sig"] = [
        1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
    ]
    cos_sims_full_df = cos_sims_full_df.sort_values(
        by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
        ascending=[False, False, True],
    )
    return cos_sims_full_df
488
+
489
+
490
# stats for identifying perturbations with largest effect within a given set of cells
# fits a mixture model to 2 components (impact vs. non-impact) and
# reports the most likely component for each test perturbation
# Note: because assumes given perturbation has a consistent effect in the cells tested,
# we recommend only using the mixture model strategy with uniform cell populations
def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
    """
    Classify perturbations as impactful vs non-impactful via a 2-component GMM.

    First pass: compute each gene's mean shift and fit a GaussianMixture to
    those means. Second pass: assign each gene (and each of its individual
    shift values) to a component via get_impact_component. combos == 1 entries
    are (anchor, token, combo) shift triples; combos == 0 entries are plain
    shift values.
    """
    names = ["Gene", "Gene_name", "Ensembl_ID"]

    if combos == 0:
        names += ["Test_avg_shift"]
    elif combos == 1:
        names += [
            "Anchor_shift",
            "Test_token_shift",
            "Sum_of_indiv_shifts",
            "Combo_shift",
            "Combo_minus_sum_shift",
        ]

    names += ["Impact_component", "Impact_component_percent"]

    cos_sims_full_df = pd.DataFrame(columns=names)
    avg_values = []
    gene_names = []

    # first pass: collect the per-gene mean shifts used to fit the GMM
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        name = cos_sims_df["Gene_name"][i]
        ensembl_id = cos_sims_df["Ensembl_ID"][i]
        cos_shift_data = []

        for dict_i in dict_list:
            if (combos == 0) and (anchor_token is not None):
                cos_shift_data += dict_i.get((anchor_token, token), [])
            else:
                cos_shift_data += dict_i.get((token, "cell_emb"), [])

        # Extract values for current gene
        if combos == 0:
            test_values = cos_shift_data
        elif combos == 1:
            # combo entries are triples; position 2 is the combo shift
            test_values = []
            for tup in cos_shift_data:
                test_values.append(tup[2])

        if len(test_values) > 0:
            avg_value = np.mean(test_values)
            avg_values.append(avg_value)
            gene_names.append(name)

    # fit Gaussian mixture model to dataset of mean for each gene
    avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
    gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)

    # second pass: score every gene against the fitted model
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        name = cos_sims_df["Gene_name"][i]
        ensembl_id = cos_sims_df["Ensembl_ID"][i]
        cos_shift_data = []

        for dict_i in dict_list:
            if (combos == 0) and (anchor_token is not None):
                cos_shift_data += dict_i.get((anchor_token, token), [])
            else:
                cos_shift_data += dict_i.get((token, "cell_emb"), [])

        if combos == 0:
            mean_test = np.mean(cos_shift_data)
            impact_components = [
                get_impact_component(value, gm) for value in cos_shift_data
            ]
        elif combos == 1:
            # NOTE: `token` is intentionally reused as the tuple-unpacking name
            # inside these comprehensions, shadowing the gene token above
            anchor_cos_sim_megalist = [
                anchor for anchor, token, combo in cos_shift_data
            ]
            token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
            # combined shift if the two individual shifts were purely additive
            anchor_plus_token_cos_sim_megalist = [
                1 - ((1 - anchor) + (1 - token))
                for anchor, token, combo in cos_shift_data
            ]
            combo_anchor_token_cos_sim_megalist = [
                combo for anchor, token, combo in cos_shift_data
            ]
            # non-additive (synergy/antagonism) component of the combo shift
            combo_minus_sum_cos_sim_megalist = [
                combo - (1 - ((1 - anchor) + (1 - token)))
                for anchor, token, combo in cos_shift_data
            ]

            mean_anchor = np.mean(anchor_cos_sim_megalist)
            mean_token = np.mean(token_cos_sim_megalist)
            mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
            mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
            mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)

            impact_components = [
                get_impact_component(value, gm)
                for value in combo_anchor_token_cos_sim_megalist
            ]

        impact_component = get_impact_component(mean_test, gm)
        # percent of this gene's individual shifts in the impact component
        impact_component_percent = np.mean(impact_components) * 100

        data_i = [token, name, ensembl_id]
        if combos == 0:
            data_i += [mean_test]
        elif combos == 1:
            data_i += [
                mean_anchor,
                mean_token,
                mean_sum,
                mean_test,
                mean_combo_minus_sum,
            ]
        data_i += [impact_component, impact_component_percent]

        cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
        cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])

    # quantify number of detections of each gene
    cos_sims_full_df["N_Detections"] = [
        n_detections(i, dict_list, "gene", anchor_token)
        for i in cos_sims_full_df["Gene"]
    ]

    if combos == 0:
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
        )
    elif combos == 1:
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
        )
    return cos_sims_full_df
623
+
624
+
625
class InSilicoPerturberStats:
    """Aggregate or statistically analyze raw in silico perturbation results."""

    # Allowed values for each constructor argument, enforced by
    # validate_options(); a set member may be a literal value or a type
    # (str/int/list/dict), in which case isinstance is checked.
    valid_option_dict = {
        "mode": {
            "goal_state_shift",
            "vs_null",
            "mixture_model",
            "aggregate_data",
            "aggregate_gene_shifts",
        },
        "genes_perturbed": {"all", list},
        "combos": {0, 1},
        "anchor_gene": {None, str},
        "cell_states_to_model": {None, dict},
        "pickle_suffix": {None, str},
    }
640
+
641
    def __init__(
        self,
        mode="mixture_model",
        genes_perturbed="all",
        combos=0,
        anchor_gene=None,
        cell_states_to_model=None,
        pickle_suffix="_raw.pickle",
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
        gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
    ):
        """
        Initialize in silico perturber stats generator.

        **Parameters:**

        mode : {"goal_state_shift", "vs_null", "mixture_model", "aggregate_data", "aggregate_gene_shifts"}
            | Type of stats.
            | "goal_state_shift": perturbation vs. random for desired cell state shift
            | "vs_null": perturbation vs. null from provided null distribution dataset
            | "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
            | "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
            | "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
        genes_perturbed : "all", list
            | Genes perturbed in isp experiment.
            | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
            | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
        combos : {0, 1}
            | Whether genes were perturbed individually (0) or in pairs (1).
            | (Triplets are not currently supported by this stats class.)
        anchor_gene : None, str
            | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
            | For example, if combos=1 and anchor_gene="ENSG00000136574":
            | analyzes data for anchor gene perturbed in combination with each other gene.
            | However, if combos=0 and anchor_gene="ENSG00000136574":
            | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
        cell_states_to_model : None, dict
            | Cell states to model if testing perturbations that achieve goal state change.
            | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
            | state_key: key specifying name of column in .dataset that defines the start/goal states
            | start_state: value in the state_key column that specifies the start state
            | goal_state: value in the state_key column that specifies the goal end state
            | alt_states: list of values in the state_key column that specify the alternate end states
            | For example: {"state_key": "disease",
            |               "start_state": "dcm",
            |               "goal_state": "nf",
            |               "alt_states": ["hcm", "other1", "other2"]}
        pickle_suffix : str
            | Filename suffix identifying raw result pickles to load (e.g. "_raw.pickle").
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl ID:token).
        gene_name_id_dictionary_file : Path
            | Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
        """

        self.mode = mode
        self.genes_perturbed = genes_perturbed
        self.combos = combos
        self.anchor_gene = anchor_gene
        self.cell_states_to_model = cell_states_to_model
        self.pickle_suffix = pickle_suffix

        # validate before loading dictionaries so bad options fail fast
        self.validate_options()

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # load gene name dictionary (gene name:Ensembl ID)
        with open(gene_name_id_dictionary_file, "rb") as f:
            self.gene_name_id_dict = pickle.load(f)

        # resolve the anchor gene's token once, if an anchor was given
        if anchor_gene is None:
            self.anchor_token = None
        else:
            self.anchor_token = self.gene_token_dict[self.anchor_gene]
714
+
715
    def validate_options(self):
        """
        Validate constructor options against valid_option_dict and normalize
        cell_states_to_model, including converting the deprecated single-key
        tuple format to the named-key format.

        NOTE(review): the error branches below use a bare `raise` after
        logging, which surfaces as "RuntimeError: No active exception to
        re-raise" rather than a descriptive exception.
        """
        # generic option check: each attr must be a listed literal or an
        # instance of a listed type
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]
            if type(attr_value) not in {list, dict}:
                if attr_name in {"anchor_gene"}:
                    continue
                elif attr_value in valid_options:
                    continue
            valid_type = False
            for option in valid_options:
                if (option in [str, int, list, dict]) and isinstance(
                    attr_value, option
                ):
                    valid_type = True
                    break
            if not valid_type:
                logger.error(
                    f"Invalid option for {attr_name}. "
                    f"Valid options for {attr_name}: {valid_options}"
                )
                raise

        if self.cell_states_to_model is not None:
            if len(self.cell_states_to_model.items()) == 1:
                # deprecated single-key format: {state_key: (start, goal, alts)}
                logger.warning(
                    "The single value dictionary for cell_states_to_model will be "
                    "replaced with a dictionary with named keys for start, goal, and alternate states. "
                    "Please specify state_key, start_state, goal_state, and alt_states "
                    "in the cell_states_to_model dictionary for future use. "
                    "For example, cell_states_to_model={"
                    "'state_key': 'disease', "
                    "'start_state': 'dcm', "
                    "'goal_state': 'nf', "
                    "'alt_states': ['hcm', 'other1', 'other2']}"
                )
                # accept only a well-formed (start, goal, alts) triple of lists
                # with unique values before converting
                for key, value in self.cell_states_to_model.items():
                    if (len(value) == 3) and isinstance(value, tuple):
                        if (
                            isinstance(value[0], list)
                            and isinstance(value[1], list)
                            and isinstance(value[2], list)
                        ):
                            if len(value[0]) == 1 and len(value[1]) == 1:
                                all_values = value[0] + value[1] + value[2]
                                if len(all_values) == len(set(all_values)):
                                    continue
                # reformat to the new named key format
                state_values = flatten_list(list(self.cell_states_to_model.values()))
                self.cell_states_to_model = {
                    "state_key": list(self.cell_states_to_model.keys())[0],
                    "start_state": state_values[0][0],
                    "goal_state": state_values[1][0],
                    "alt_states": state_values[2:][0],
                }
            elif set(self.cell_states_to_model.keys()) == {
                "state_key",
                "start_state",
                "goal_state",
                "alt_states",
            }:
                # named-key format: the three primary fields are mandatory
                if (
                    (self.cell_states_to_model["state_key"] is None)
                    or (self.cell_states_to_model["start_state"] is None)
                    or (self.cell_states_to_model["goal_state"] is None)
                ):
                    logger.error(
                        "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
                    )
                    raise

                if (
                    self.cell_states_to_model["start_state"]
                    == self.cell_states_to_model["goal_state"]
                ):
                    logger.error("All states must be unique.")
                    raise

                if self.cell_states_to_model["alt_states"] is not None:
                    if not isinstance(self.cell_states_to_model["alt_states"], list):
                        logger.error(
                            "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
                        )
                        raise
                    if len(self.cell_states_to_model["alt_states"]) != len(
                        set(self.cell_states_to_model["alt_states"])
                    ):
                        logger.error("All states must be unique.")
                        raise

            else:
                logger.error(
                    "cell_states_to_model must only have the following four keys: "
                    "'state_key', 'start_state', 'goal_state', 'alt_states'."
                    "For example, cell_states_to_model={"
                    "'state_key': 'disease', "
                    "'start_state': 'dcm', "
                    "'goal_state': 'nf', "
                    "'alt_states': ['hcm', 'other1', 'other2']}"
                )
                raise

            # anchor genes are not supported when modeling cell states
            if self.anchor_gene is not None:
                self.anchor_gene = None
                logger.warning(
                    "anchor_gene set to None. "
                    "Currently, anchor gene not available "
                    "when modeling multiple cell states."
                )

        # combination perturbation stats require an anchor gene
        if self.combos > 0:
            if self.anchor_gene is None:
                logger.error(
                    "Currently, stats are only supported for combination "
                    "in silico perturbation run with anchor gene. Please add "
                    "anchor gene when using with combos > 0. "
                )
                raise

        # mode/genes_perturbed compatibility checks
        if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
            logger.error(
                "Mixture model mode requires multiple gene perturbations to fit model "
                "so is incompatible with a single grouped perturbation."
            )
            raise
        if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
            logger.error(
                "Simple data aggregation mode is for single perturbation in multiple cells "
                "so is incompatible with a genes_perturbed being 'all'."
            )
            raise
845
+
846
def get_stats(
    self,
    input_data_directory,
    null_dist_data_directory,
    output_directory,
    output_prefix,
    null_dict_list=None,
):
    """
    Get stats for in silico perturbation data and save as results in output_directory.

    **Parameters:**

    input_data_directory : Path
        | Path to directory containing cos_sim dictionary inputs
    null_dist_data_directory : Path
        | Path to directory containing null distribution cos_sim dictionary inputs
    output_directory : Path
        | Path to directory where perturbation data will be saved as .csv
    output_prefix : str
        | Prefix for output .csv
    null_dict_list: list[dict]
        | List of loaded null distribution dictionary if more than one comparison vs. the null is to be performed

    **Outputs:**

    Definition of possible columns in .csv output file.

    | Of note, not all columns will be present in all output files.
    | Some columns are specific to particular perturbation modes.

    | "Gene": gene token
    | "Gene_name": gene name
    | "Ensembl_ID": gene Ensembl ID
    | "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
    | "Sig": 1 if FDR<0.05, otherwise 0

    | "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
    | "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
    | "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
    |     pvalue compares shift caused by perturbing given gene compared to random genes
    | "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
    |     pvalue compares shift caused by perturbing given gene compared to random genes
    | "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
    | "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"

    | "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
    | "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
    | "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
    |     (i.e. "Test_avg_shift" minus "Null_avg_shift")
    | "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
    | "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
    | "N_Detections_test": "N_Detections" in cells from test distribution
    | "N_Detections_null": "N_Detections" in cells from null distribution

    | "Anchor_shift": cosine shift in response to given perturbation of anchor gene
    | "Test_token_shift": cosine shift in response to given perturbation of test gene
    | "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
    | "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
    | "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
    |     (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
    | "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
    |     1: within impact component; 0: not within impact component
    | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component

    | In case of aggregating gene shifts:
    | "Perturbed": ID(s) of gene(s) being perturbed
    | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
    | "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
    | "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
    """

    if self.mode not in [
        "goal_state_shift",
        "vs_null",
        "mixture_model",
        "aggregate_data",
        "aggregate_gene_shifts",
    ]:
        logger.error(
            "Currently, only modes available are stats for goal_state_shift, "
            "vs_null (comparing to null distribution), "
            "mixture_model (fitting mixture model for perturbations with or without impact), "
            "and aggregating data for single perturbations or for gene embedding shifts."
        )
        # NOTE(review): bare `raise` with no active exception surfaces as
        # "RuntimeError: No active exception to re-raise" — an explicit
        # ValueError would be clearer. Kept as-is here.
        raise

    # build reverse lookups: token -> Ensembl ID and Ensembl ID -> gene name
    self.gene_token_id_dict = invert_dict(self.gene_token_dict)
    self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)

    # obtain total gene list
    if (self.combos == 0) and (self.anchor_token is not None):
        # cos sim data for effect of gene perturbation on the embedding of each other gene
        dict_list = read_dictionaries(
            input_data_directory,
            "gene",
            self.anchor_token,
            self.cell_states_to_model,
            self.pickle_suffix,
        )
        gene_list = get_gene_list(dict_list, "gene")
    elif (
        (self.combos == 0)
        and (self.anchor_token is None)
        and (self.mode == "aggregate_gene_shifts")
    ):
        # gene-shift aggregation: gene-level dictionaries keyed by cell
        dict_list = read_dictionaries(
            input_data_directory,
            "gene",
            self.anchor_token,
            self.cell_states_to_model,
            self.pickle_suffix,
        )
        gene_list = get_gene_list(dict_list, "cell")
    else:
        # cos sim data for effect of gene perturbation on the embedding of each cell
        dict_list = read_dictionaries(
            input_data_directory,
            "cell",
            self.anchor_token,
            self.cell_states_to_model,
            self.pickle_suffix,
        )
        gene_list = get_gene_list(dict_list, "cell")

    # initiate results dataframe
    # Ensembl_ID column: grouped perturbations use the full token tuple;
    # otherwise a tuple's second element (the non-anchor token) or the bare token is mapped
    cos_sims_df_initial = pd.DataFrame(
        {
            "Gene": gene_list,
            "Gene_name": [self.token_to_gene_name(item) for item in gene_list],
            "Ensembl_ID": [
                token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
                if self.genes_perturbed != "all"
                else self.gene_token_id_dict[genes[1]]
                if isinstance(genes, tuple)
                else self.gene_token_id_dict[genes]
                for genes in gene_list
            ],
        },
        index=[i for i in range(len(gene_list))],
    )

    # dispatch to the stats routine matching the configured mode
    if self.mode == "goal_state_shift":
        cos_sims_df = isp_stats_to_goal_state(
            cos_sims_df_initial,
            dict_list,
            self.cell_states_to_model,
            self.genes_perturbed,
        )

    elif self.mode == "vs_null":
        if null_dict_list is None:
            # load null distribution from disk only when not supplied by the caller
            null_dict_list = read_dictionaries(
                null_dist_data_directory,
                "cell",
                self.anchor_token,
                self.cell_states_to_model,
                self.pickle_suffix,
            )
        cos_sims_df = isp_stats_vs_null(
            cos_sims_df_initial, dict_list, null_dict_list
        )

    elif self.mode == "mixture_model":
        cos_sims_df = isp_stats_mixture_model(
            cos_sims_df_initial, dict_list, self.combos, self.anchor_token
        )

    elif self.mode == "aggregate_data":
        cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)

    elif self.mode == "aggregate_gene_shifts":
        cos_sims_df = isp_aggregate_gene_shifts(
            cos_sims_df_initial,
            dict_list,
            self.gene_token_id_dict,
            self.gene_id_name_dict,
        )

    # save perturbation stats to output_path
    output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
    cos_sims_df.to_csv(output_path)
1028
+
1029
def token_to_gene_name(self, item):
    """Map a token id (or tuple of token ids) to gene name(s).

    Unknown tokens map to np.nan at each lookup stage. Items that are
    neither integers nor tuples fall through and yield None.
    """

    def _name_for(token):
        # token -> Ensembl ID -> gene name, with np.nan at every miss
        ensembl_id = self.gene_token_id_dict.get(token, np.nan)
        return self.gene_id_name_dict.get(ensembl_id, np.nan)

    if np.issubdtype(type(item), np.integer):
        return _name_for(item)
    if isinstance(item, tuple):
        return tuple(_name_for(token) for token in item)
geneformer/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14
3
+ size 41183536
geneformer/perturber_utils.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools as it
2
+ import logging
3
+ import pickle
4
+ import re
5
+ from collections import defaultdict
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import seaborn as sns
11
+ import torch
12
+ from datasets import Dataset, load_from_disk
13
+ from transformers import (
14
+ BertForMaskedLM,
15
+ BertForSequenceClassification,
16
+ BertForTokenClassification,
17
+ )
18
+
19
+ sns.set()
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ # load data and filter by defined criteria
25
def load_and_filter(filter_data, nproc, input_data_file):
    """Load a Hugging Face dataset from disk and optionally filter it.

    Parameters
    ----------
    filter_data : dict | None
        Mapping of column name -> collection of allowed values; None skips filtering.
    nproc : int
        Number of processes for dataset filtering.
    input_data_file : str | Path
        Path to the dataset saved with ``save_to_disk``.
    """
    data = load_from_disk(input_data_file)
    if filter_data is not None:
        data = filter_by_dict(data, filter_data, nproc)
    return data


def filter_by_dict(data, filter_data, nproc):
    """Keep only examples whose value for each key is in the allowed collection.

    Raises
    ------
    ValueError
        If no cells remain after filtering. (Previously a bare ``raise`` with no
        active exception, which surfaced as an opaque RuntimeError.)
    """
    for key, value in filter_data.items():

        def filter_data_by_criteria(example):
            return example[key] in value

        data = data.filter(filter_data_by_criteria, num_proc=nproc)
    if len(data) == 0:
        err_msg = "No cells remain after filtering. Check filtering criteria."
        logger.error(err_msg)
        raise ValueError(err_msg)
    return data
43
+
44
+
45
def filter_data_by_tokens(filtered_input_data, tokens, nproc):
    """Keep only cells whose rank-value encoding contains every token in `tokens`."""

    def if_has_tokens(example):
        present = set(example["input_ids"]).intersection(tokens)
        return len(present) == len(tokens)

    return filtered_input_data.filter(if_has_tokens, num_proc=nproc)


def logging_filtered_data_len(filtered_input_data, filtered_tokens_categ):
    """Log how many cells remain after token filtering; error out if none do."""
    if len(filtered_input_data) == 0:
        logger.error(f"No cells in dataset contain {filtered_tokens_categ}.")
        raise
    else:
        logger.info(f"# cells with {filtered_tokens_categ}: {len(filtered_input_data)}")


def filter_data_by_tokens_and_log(
    filtered_input_data, tokens, nproc, filtered_tokens_categ
):
    """Filter cells to those containing all `tokens`, then log the resulting count."""
    # filter for cells with anchor gene
    retained = filter_data_by_tokens(filtered_input_data, tokens, nproc)
    # logging length of filtered data
    logging_filtered_data_len(retained, filtered_tokens_categ)
    return retained
70
+
71
+
72
def filter_data_by_start_state(filtered_input_data, cell_states_to_model, nproc):
    """Subset the dataset to cells whose state column equals the modeled start state.

    Raises
    ------
    ValueError
        If the start state does not occur in the dataset's state column.
        (Previously a bare ``raise``, which surfaced as an opaque RuntimeError.)
    """
    # confirm that start state is valid to prevent futile filtering
    state_key = cell_states_to_model["state_key"]
    state_values = filtered_input_data[state_key]
    start_state = cell_states_to_model["start_state"]
    if start_state not in state_values:
        err_msg = (
            f"Start state {start_state} is not present "
            f"in the dataset's {state_key} attribute."
        )
        logger.error(err_msg)
        raise ValueError(err_msg)

    # filter for start state cells
    def filter_for_origin(example):
        return example[state_key] in [start_state]

    return filtered_input_data.filter(filter_for_origin, num_proc=nproc)


def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
    """Select the [start, end) slice of cells to perturb.

    Mutates ``cell_inds_to_perturb["end"]`` in place when it is clamped to the
    dataset length (original behavior preserved).

    Raises
    ------
    ValueError
        If ``start`` is beyond the end of the dataset.
    """
    if cell_inds_to_perturb["start"] >= len(filtered_input_data):
        err_msg = "cell_inds_to_perturb['start'] is larger than the filtered dataset."
        logger.error(err_msg)
        raise ValueError(err_msg)
    if cell_inds_to_perturb["end"] > len(filtered_input_data):
        logger.warning(
            "cell_inds_to_perturb['end'] is larger than the filtered dataset. \
            Setting to the end of the filtered dataset."
        )
        cell_inds_to_perturb["end"] = len(filtered_input_data)
    filtered_input_data = filtered_input_data.select(
        [i for i in range(cell_inds_to_perturb["start"], cell_inds_to_perturb["end"])]
    )
    return filtered_input_data
108
+
109
+
110
+ # load model to GPU
111
def load_model(model_type, num_classes, model_directory):
    """Load a Geneformer checkpoint in eval mode and move it to cuda:0.

    Parameters
    ----------
    model_type : str
        One of "Pretrained" (masked LM), "GeneClassifier" (token classification),
        or "CellClassifier" (sequence classification).
    num_classes : int
        Number of labels for the classifier heads (ignored for "Pretrained").
    model_directory : str | Path
        Directory or hub id passed to ``from_pretrained``.

    Raises
    ------
    ValueError
        For an unrecognized ``model_type`` (previously fell through to an
        UnboundLocalError at ``model.eval()``).
    """
    if model_type == "Pretrained":
        model = BertForMaskedLM.from_pretrained(
            model_directory, output_hidden_states=True, output_attentions=False
        )
    elif model_type == "GeneClassifier":
        model = BertForTokenClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=True,
            output_attentions=False,
        )
    elif model_type == "CellClassifier":
        model = BertForSequenceClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=True,
            output_attentions=False,
        )
    else:
        raise ValueError(
            f"Unknown model_type: {model_type!r}; expected 'Pretrained', "
            "'GeneClassifier', or 'CellClassifier'."
        )
    # put the model in eval mode for fwd pass
    model.eval()
    model = model.to("cuda:0")
    return model
134
+
135
+
136
def quant_layers(model):
    """Return the number of hidden layers, inferred from parameter names.

    Scans ``named_parameters()`` for names containing "layer" and takes the
    maximum layer index + 1.
    """
    layer_nums = []
    for name, parameter in model.named_parameters():
        if "layer" in name:
            layer_nums += [int(name.split("layer.")[1].split(".")[0])]
    return int(max(layer_nums)) + 1


def get_model_input_size(model):
    """Return the model's maximum input length from the position-embedding repr.

    Parses the first number out of ``str(model.bert.embeddings.position_embeddings)``
    (e.g. "Embedding(2048, 256)" -> 2048).
    """
    # raw string: "\(" in a plain literal is an invalid escape sequence
    # (SyntaxWarning in Python 3.12+, future SyntaxError)
    return int(re.split(r"\(|,", str(model.bert.embeddings.position_embeddings))[1])
146
+
147
+
148
def flatten_list(megalist):
    """Flatten one level of nesting from a list of lists."""
    return [element for inner in megalist for element in inner]


def measure_length(example):
    """Record the rank-value-encoding length on the example under 'length'."""
    example["length"] = len(example["input_ids"])
    return example


def downsample_and_sort(data, max_ncells):
    """Optionally subsample to at most max_ncells cells, then sort longest-first."""
    num_cells = len(data)
    # if max number of cells is defined, then shuffle and subsample to this max number
    if max_ncells is not None and num_cells > max_ncells:
        data = data.shuffle(seed=42)
        num_cells = max_ncells
    data_subset = data.select([i for i in range(num_cells)])
    # sort dataset with largest cell first to encounter any memory errors earlier
    return data_subset.sort("length", reverse=True)


def get_possible_states(cell_states_to_model):
    """Return [start_state, goal_state, *alt_states] for the modeled transition."""
    states = [cell_states_to_model[key] for key in ("start_state", "goal_state")]
    states += cell_states_to_model.get("alt_states", [])
    return states
176
+
177
+
178
def forward_pass_single_cell(model, example_cell, layer_to_quant):
    """Run one cell through the model; return its squeezed hidden states from layer_to_quant."""
    example_cell.set_format(type="torch")
    token_ids = example_cell["input_ids"]
    with torch.no_grad():
        model_output = model(input_ids=token_ids.to("cuda"))
    emb = torch.squeeze(model_output.hidden_states[layer_to_quant])
    del model_output
    return emb


def perturb_emb_by_index(emb, indices):
    """Return emb with the flat positions in `indices` removed (boolean-mask select)."""
    keep_mask = torch.ones(emb.numel(), dtype=torch.bool)
    keep_mask[indices] = False
    return emb[keep_mask]
192
+
193
+
194
def delete_indices(example):
    """Remove the tokens at example['perturb_index'] and refresh 'length'."""
    target_indices = example["perturb_index"]
    if any(isinstance(el, list) for el in target_indices):
        target_indices = flatten_list(target_indices)
    # delete from the back so earlier positions stay valid
    for idx in sorted(target_indices, reverse=True):
        del example["input_ids"][idx]

    example["length"] = len(example["input_ids"])
    return example


# for genes_to_perturb = "all" where only genes within cell are overexpressed
def overexpress_indices(example):
    """Move the tokens at perturb_index to the front of the rank-value encoding."""
    target_indices = example["perturb_index"]
    if any(isinstance(el, list) for el in target_indices):
        target_indices = flatten_list(target_indices)
    for idx in sorted(target_indices, reverse=True):
        example["input_ids"].insert(0, example["input_ids"].pop(idx))

    example["length"] = len(example["input_ids"])
    return example


# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
def overexpress_tokens(example, max_len):
    """Prepend tokens_to_perturb (removing in-cell copies first) and truncate to max_len."""
    # -100 indicates tokens to overexpress are not present in rank value encoding
    if example["perturb_index"] != [-100]:
        example = delete_indices(example)
    for token in example["tokens_to_perturb"][::-1]:
        example["input_ids"].insert(0, token)

    # truncate to max input size, must also truncate original emb to be comparable
    if len(example["input_ids"]) > max_len:
        example["input_ids"] = example["input_ids"][0:max_len]

    example["length"] = len(example["input_ids"])
    return example
233
+
234
+
235
def calc_n_overflow(max_len, example_len, tokens_to_perturb, indices_to_perturb):
    """Tokens by which the encoding would exceed max_len after overexpression.

    Net growth is (#tokens added) - (#in-cell copies removed).
    """
    net_added = len(tokens_to_perturb) - len(indices_to_perturb)
    return example_len + net_added - max_len


def truncate_by_n_overflow(example):
    """Trim 'n_overflow' tokens from the end of input_ids and refresh 'length'."""
    keep_len = example["length"] - example["n_overflow"]
    example["input_ids"] = example["input_ids"][0:keep_len]
    example["length"] = len(example["input_ids"])
    return example
246
+
247
+
248
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
    """Return emb with the positions in `indices_to_remove` dropped along `gene_dim`."""
    # indices_to_remove is list of indices to remove
    keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
    slicer = tuple(
        keep if axis == gene_dim else slice(None) for axis in range(emb.dim())
    )
    return emb[slicer]


def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
    """Drop per-cell perturbed positions, then pad back to a rectangular batch."""
    trimmed = [
        remove_indices_from_emb(emb_batch[i, :, :], idxes, gene_dim - 1)
        for i, idxes in enumerate(list_of_indices_to_remove)
    ]
    # add padding given genes are sometimes added that are or are not in original cell
    batch_max = max(emb.size()[gene_dim - 1] for emb in trimmed)
    padded = [pad_xd_tensor(emb, 0.000, batch_max, gene_dim - 1) for emb in trimmed]
    return torch.stack(padded)
272
+
273
+
274
+ # removes perturbed indices
275
+ # need to handle the various cases where a set of genes is overexpressed
276
def remove_perturbed_indices_set(
    emb,
    perturb_type: str,
    indices_to_perturb: List[List],
    tokens_to_perturb: List[List],
    original_lengths: List[int],
    input_ids=None,
):
    """Remove the perturbed gene positions from a batch of embeddings.

    Handles the overexpression cases where a gene may be absent from the
    original encoding (marked by [-100]): such entries are converted to [None]
    so no position is removed for that cell.

    Parameters
    ----------
    emb : torch.Tensor
        Batch of embeddings; genes along dim 1.
    perturb_type : str
        "overexpress" triggers the [-100] handling; other types pass indices through.
    indices_to_perturb, tokens_to_perturb, original_lengths, input_ids
        Per-cell perturbation bookkeeping (input_ids is unused here).
    """
    if perturb_type == "overexpress":
        num_perturbed = len(tokens_to_perturb)
        if num_perturbed == 1:
            indices_to_perturb_orig = [
                idx if idx != [-100] else [None] for idx in indices_to_perturb
            ]
            # BUGFIX: was `v is [None]`, an identity comparison against a fresh
            # list literal that is always False, so the early return never fired.
            if all(v == [None] for v in indices_to_perturb_orig):
                return emb
        else:
            indices_to_perturb_orig = []

            for idx_list in indices_to_perturb:
                indices_to_perturb_orig.append(
                    [idx if idx != [-100] else [None] for idx in idx_list]
                )

    else:
        indices_to_perturb_orig = indices_to_perturb

    emb = remove_indices_from_emb_batch(emb, indices_to_perturb_orig, gene_dim=1)

    return emb
306
+
307
+
308
def make_perturbation_batch(
    example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
) -> tuple[Dataset, List[int]]:
    """Build a dataset of perturbed copies of one cell, one row per perturbation.

    Returns the perturbation Dataset and the list of perturbed index groups
    (each group is a list of positions in the original rank-value encoding).
    """
    if combo_lvl == 0 and tokens_to_perturb == "all":
        # perturb every gene in the cell individually; overexpression skips
        # position 0 (the gene already at the top rank)
        if perturb_type in ["overexpress", "activate"]:
            range_start = 1
        elif perturb_type in ["delete", "inhibit"]:
            range_start = 0
        indices_to_perturb = [
            [i] for i in range(range_start, example_cell["length"][0])
        ]
    # elif combo_lvl > 0 and anchor_token is None:
    ## to implement
    elif combo_lvl > 0 and (anchor_token is not None):
        # pair the anchor gene with every other gene in the cell
        example_input_ids = example_cell["input_ids"][0]
        anchor_index = example_input_ids.index(anchor_token[0])
        indices_to_perturb = [
            sorted([anchor_index, i]) if i != anchor_index else None
            for i in range(example_cell["length"][0])
        ]
        indices_to_perturb = [item for item in indices_to_perturb if item is not None]
    else:
        # specific token list: keep only tokens actually present in this cell
        example_input_ids = example_cell["input_ids"][0]
        indices_to_perturb = [
            [example_input_ids.index(token)] if token in example_input_ids else None
            for token in tokens_to_perturb
        ]
        indices_to_perturb = [item for item in indices_to_perturb if item is not None]

    # create all permutations of combo_lvl of modifiers from tokens_to_perturb
    if combo_lvl > 0 and (anchor_token is None):
        if tokens_to_perturb != "all":
            if len(tokens_to_perturb) == combo_lvl + 1:
                indices_to_perturb = [
                    list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
                ]
        else:
            # combine the fixed perturbed set with each remaining gene in turn
            all_indices = [[i] for i in range(example_cell["length"][0])]
            all_indices = [
                index for index in all_indices if index not in indices_to_perturb
            ]
            indices_to_perturb = [
                [[j for i in indices_to_perturb for j in i], x] for x in all_indices
            ]

    length = len(indices_to_perturb)
    perturbation_dataset = Dataset.from_dict(
        {
            "input_ids": example_cell["input_ids"] * length,
            "perturb_index": indices_to_perturb,
        }
    )

    # multiprocessing overhead is not worth it for small batches
    if length < 400:
        num_proc_i = 1
    else:
        num_proc_i = num_proc

    if perturb_type == "delete":
        perturbation_dataset = perturbation_dataset.map(
            delete_indices, num_proc=num_proc_i
        )
    elif perturb_type == "overexpress":
        perturbation_dataset = perturbation_dataset.map(
            overexpress_indices, num_proc=num_proc_i
        )

    perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)

    return perturbation_dataset, indices_to_perturb
378
+
379
+
380
+ # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
381
+ # so that only non-perturbed gene embeddings are compared to each other
382
+ # in original or perturbed context
383
def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
    """Build original-embedding counterparts with the perturbed positions removed.

    Returns a stacked tensor where each row is the original embedding of a cell
    minus the gene positions that were perturbed, so perturbed and original
    contexts compare only non-perturbed genes.
    """
    all_embs_list = []

    # if making comparison batch for multiple perturbations in single cell
    if perturb_group is False:
        # squeeze if single cell
        if original_emb_batch.ndim == 3 and original_emb_batch.size()[0] == 1:
            original_emb_batch = torch.squeeze(original_emb_batch)
        original_emb_list = [original_emb_batch] * len(indices_to_perturb)
    # if making comparison batch for single perturbation in multiple cells
    elif perturb_group is True:
        original_emb_list = original_emb_batch

    for original_emb, indices in zip(original_emb_list, indices_to_perturb):
        # [-100]: perturbed gene absent from this cell, so nothing to remove
        if indices == [-100]:
            all_embs_list += [original_emb[:]]
            continue

        emb_list = []
        start = 0
        if any(isinstance(el, list) for el in indices):
            indices = flatten_list(indices)

        # removes indices that were perturbed from the original embedding
        for i in sorted(indices):
            emb_list += [original_emb[start:i]]
            start = i + 1

        emb_list += [original_emb[start:]]
        all_embs_list += [torch.cat(emb_list)]

    # cells may have lost different numbers of genes; pad to a common length
    len_set = set([emb.size()[0] for emb in all_embs_list])
    if len(len_set) > 1:
        max_len = max(len_set)
        # NOTE(review): pad value None relies on F.pad treating None as the
        # default fill — confirm against the installed torch version
        all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
    return torch.stack(all_embs_list)
419
+
420
+
421
def pad_list(input_ids, pad_token_id, max_len):
    """Right-pad a token-id sequence with pad_token_id out to max_len (np.ndarray)."""
    pad_width = (0, max_len - len(input_ids))
    return np.pad(
        input_ids, pad_width, mode="constant", constant_values=pad_token_id
    )


def pad_xd_tensor(tensor, pad_token_id, max_len, dim):
    """Right-pad `tensor` along `dim` with pad_token_id so size(dim) == max_len."""
    amount = max_len - tensor.size()[dim]
    # F.pad's config lists (before, after) pairs starting from the LAST
    # dimension; leave everything 0 except the "after" slot of `dim`
    pad_config = [0] * (2 * tensor.dim())
    pad_config[-2 * dim - 1] = amount
    return torch.nn.functional.pad(
        tensor, pad=pad_config, mode="constant", value=pad_token_id
    )
441
+
442
+
443
def pad_tensor(tensor, pad_token_id, max_len):
    """Right-pad a 1D tensor with pad_token_id out to max_len elements."""
    n_missing = max_len - tensor.numel()
    return torch.nn.functional.pad(
        tensor, pad=(0, n_missing), mode="constant", value=pad_token_id
    )


def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
    """Right-pad a 2D tensor along `dim` (0 = rows, 1 = cols) out to max_len."""
    n_missing = max_len - tensor.size()[dim]
    if dim == 0:
        pad = (0, 0, 0, n_missing)
    elif dim == 1:
        pad = (0, n_missing, 0, 0)
    return torch.nn.functional.pad(
        tensor, pad=pad, mode="constant", value=pad_token_id
    )


def pad_3d_tensor(tensor, pad_token_id, max_len, dim):
    """Right-pad a 3D tensor along dim 1 or 2; dim 0 (batch) is unsupported."""
    if dim == 0:
        raise Exception("dim 0 usually does not need to be padded.")
    if dim == 1:
        pad = (0, 0, 0, max_len - tensor.size()[dim])
    elif dim == 2:
        pad = (0, max_len - tensor.size()[dim], 0, 0)
    return torch.nn.functional.pad(
        tensor, pad=pad, mode="constant", value=pad_token_id
    )


def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
    """Force `encoding` (tensor or list) to exactly max_len: slice if long, pad if short."""
    if isinstance(encoding, torch.Tensor):
        encoding_len = encoding.size()[0]
    elif isinstance(encoding, list):
        encoding_len = len(encoding)
    if encoding_len > max_len:
        return encoding[0:max_len]
    if encoding_len < max_len:
        if isinstance(encoding, torch.Tensor):
            return pad_tensor(encoding, pad_token_id, max_len)
        return pad_list(encoding, pad_token_id, max_len)
    return encoding


# pad list of tensors and convert to tensor
def pad_tensor_list(
    tensor_list,
    dynamic_or_constant,
    pad_token_id,
    model_input_size,
    dim=None,
    padding_func=None,
):
    """Pad every tensor in tensor_list to a common length, then stack (or cat for 3D).

    dynamic_or_constant: "dynamic" pads to the longest member; an int pads to
    that length; any other value falls back to model_input_size with a warning.
    """
    # determine maximum tensor length
    if dynamic_or_constant == "dynamic":
        max_len = max(tensor.squeeze().numel() for tensor in tensor_list)
    elif isinstance(dynamic_or_constant, int):
        max_len = dynamic_or_constant
    else:
        max_len = model_input_size
        logger.warning(
            "If padding style is constant, must provide integer value. "
            f"Setting padding to max input size {model_input_size}."
        )

    # pad all tensors to maximum length
    if dim is None:
        padded = [pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list]
    else:
        padded = [
            padding_func(tensor, pad_token_id, max_len, dim) for tensor in tensor_list
        ]
    # 3D pieces are batch fragments: concatenate along the batch dim instead of stacking
    if padding_func != pad_3d_tensor:
        return torch.stack(padded)
    return torch.cat(padded, 0)
525
+
526
+
527
def gen_attention_mask(minibatch_encoding, max_len=None):
    """Build a 1/0 attention-mask tensor (on cuda) from per-example lengths."""
    original_lens = minibatch_encoding["length"]
    if max_len is None:
        max_len = max(original_lens)
    mask_rows = []
    for seq_len in original_lens:
        if seq_len <= max_len:
            mask_rows.append([1] * seq_len + [0] * (max_len - seq_len))
        else:
            mask_rows.append([1] * max_len)
    return torch.tensor(mask_rows, device="cuda")


# get cell embeddings excluding padding
def mean_nonpadding_embs(embs, original_lens, dim=1):
    """Mean of embeddings over non-padding positions along `dim` for 2D or 3D input."""
    # mask tensor marking valid (non-padding) positions per example
    valid = torch.arange(embs.size(dim), device=embs.device) < original_lens.unsqueeze(1)
    if embs.dim() == 3:
        # zero out padding, then divide the sum by the true length
        zeroed = embs.masked_fill(~valid.unsqueeze(2), 0.0)
        mean_embs = zeroed.sum(dim) / original_lens.view(-1, 1).float()

    elif embs.dim() == 2:
        zeroed = embs.masked_fill(~valid, 0.0)
        mean_embs = zeroed.sum(dim) / original_lens.float()
    return mean_embs


# get cell embeddings when there is no padding
def compute_nonpadded_cell_embedding(embs, cell_emb_style):
    """Mean-pool over the gene dimension (second-to-last axis)."""
    if cell_emb_style == "mean_pool":
        return torch.mean(embs, dim=embs.ndim - 2)
561
+
562
+
563
+ # quantify shifts for a set of genes
564
# quantify shifts for a set of genes
def quant_cos_sims(
    perturbation_emb,
    original_emb,
    cell_states_to_model,
    state_embs_dict,
    emb_mode="gene",
):
    """Cosine similarity of perturbed vs. original embeddings, or per-state shifts.

    With cell_states_to_model set, returns {state: shift tensor}; otherwise the
    raw cosine similarities (moved to cuda).
    """
    if emb_mode == "gene":
        cos = torch.nn.CosineSimilarity(dim=2)
    elif emb_mode == "cell":
        cos = torch.nn.CosineSimilarity(dim=1)

    if cell_states_to_model is None:
        return cos(perturbation_emb, original_emb).to("cuda")

    possible_states = get_possible_states(cell_states_to_model)
    cos_sims = {state: [] for state in possible_states}
    for state in possible_states:
        cos_sims[state] = cos_sim_shift(
            original_emb,
            perturbation_emb,
            state_embs_dict[state].to("cuda"),  # required to move to cuda here
            cos,
        )
    return cos_sims


# calculate cos sim shift of perturbation with respect to origin and alternative cell
def cos_sim_shift(original_emb, perturbed_emb, end_emb, cos):
    """Shift toward end_emb: cos(perturbed, end) - cos(original, end)."""
    return cos(perturbed_emb, end_emb) - cos(original_emb, end_emb)


def concatenate_cos_sims(cos_sims):
    """Concatenate a list of cos-sim tensors, or each value of a per-state dict."""
    if isinstance(cos_sims, list):
        return torch.cat(cos_sims)
    for state in cos_sims.keys():
        cos_sims[state] = torch.cat(cos_sims[state])
    return cos_sims
607
+
608
+
609
def write_perturbation_dictionary(cos_sims_dict: defaultdict, output_path_prefix: str):
    """Pickle the cos-sim dictionary to '<output_path_prefix>_raw.pickle'."""
    with open(f"{output_path_prefix}_raw.pickle", "wb") as handle:
        pickle.dump(cos_sims_dict, handle)


def tensor_list_to_pd(tensor_list):
    """Concatenate a list of tensors and return the result as a pandas DataFrame."""
    stacked = torch.cat(tensor_list).cpu().numpy()
    return pd.DataFrame(stacked)
618
+
619
+
620
def validate_cell_states_to_model(cell_states_to_model):
    """Validate (and, for the legacy single-key format, reformat) cell_states_to_model.

    Expected modern format:
    {"state_key": ..., "start_state": ..., "goal_state": ..., "alt_states": [...]}

    NOTE(review): the legacy-format branch rebinds the LOCAL name
    `cell_states_to_model` to the reformatted dict but does not return it, so
    the caller never sees the new format — confirm against call sites (the
    in-class variant assigns to self.cell_states_to_model instead).
    """
    if cell_states_to_model is not None:
        # legacy format: a single {state_key: (start, goal, alts)} entry
        if len(cell_states_to_model.items()) == 1:
            logger.warning(
                "The single value dictionary for cell_states_to_model will be "
                "replaced with a dictionary with named keys for start, goal, and alternate states. "
                "Please specify state_key, start_state, goal_state, and alt_states "
                "in the cell_states_to_model dictionary for future use. "
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            for key, value in cell_states_to_model.items():
                # accept a 3-tuple of lists: ([start], [goal], [alt, ...])
                if (len(value) == 3) and isinstance(value, tuple):
                    if (
                        isinstance(value[0], list)
                        and isinstance(value[1], list)
                        and isinstance(value[2], list)
                    ):
                        if len(value[0]) == 1 and len(value[1]) == 1:
                            all_values = value[0] + value[1] + value[2]
                            # all states must be distinct for the format to be valid
                            if len(all_values) == len(set(all_values)):
                                continue
            # reformat to the new named key format
            state_values = flatten_list(list(cell_states_to_model.values()))

            cell_states_to_model = {
                "state_key": list(cell_states_to_model.keys())[0],
                "start_state": state_values[0][0],
                "goal_state": state_values[1][0],
                "alt_states": state_values[2:][0],
            }
        elif set(cell_states_to_model.keys()).issuperset(
            {"state_key", "start_state", "goal_state"}
        ):
            # modern named-key format: required keys must be non-None
            if (
                (cell_states_to_model["state_key"] is None)
                or (cell_states_to_model["start_state"] is None)
                or (cell_states_to_model["goal_state"] is None)
            ):
                logger.error(
                    "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
                )
                # NOTE(review): bare `raise` with no active exception surfaces
                # as an opaque RuntimeError (same pattern below)
                raise

            if (
                cell_states_to_model["start_state"]
                == cell_states_to_model["goal_state"]
            ):
                logger.error("All states must be unique.")
                raise

            if "alt_states" in set(cell_states_to_model.keys()):
                if cell_states_to_model["alt_states"] is not None:
                    if not isinstance(cell_states_to_model["alt_states"], list):
                        logger.error(
                            "cell_states_to_model['alt_states'] must be a list (even if it is one element)."
                        )
                        raise
                    if len(cell_states_to_model["alt_states"]) != len(
                        set(cell_states_to_model["alt_states"])
                    ):
                        logger.error("All states must be unique.")
                        raise
                else:
                    # normalize a None alt_states to an empty list (in-place:
                    # this mutation IS visible to the caller)
                    cell_states_to_model["alt_states"] = []

        else:
            logger.error(
                "cell_states_to_model must only have the following four keys: "
                "'state_key', 'start_state', 'goal_state', 'alt_states'."
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            raise
geneformer/pretrainer.py ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer precollator and pretrainer.
3
+
4
+ Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
5
+ """
6
+ import collections
7
+ import math
8
+ import pickle
9
+ import warnings
10
+ from enum import Enum
11
+ from typing import Dict, Iterator, List, Optional, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ from datasets import Dataset
16
+ from packaging import version
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from torch.utils.data.sampler import RandomSampler
19
+ from transformers import (
20
+ BatchEncoding,
21
+ DataCollatorForLanguageModeling,
22
+ SpecialTokensMixin,
23
+ Trainer,
24
+ )
25
+ from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
26
+ from transformers.trainer_pt_utils import (
27
+ DistributedLengthGroupedSampler,
28
+ DistributedSamplerWithLoop,
29
+ LengthGroupedSampler,
30
+ )
31
+ from transformers.training_args import ParallelMode
32
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
+ from transformers.utils.generic import _is_tensorflow, _is_torch
34
+
35
+ from .tokenizer import TOKEN_DICTIONARY_FILE
36
+
37
+ logger = logging.get_logger(__name__)
38
+ EncodedInput = List[int]
39
+ VERY_LARGE_INTEGER = int(
40
+ 1e30
41
+ ) # This is used to set the max input length for a model with infinite size input
42
+ LARGE_INTEGER = int(
43
+ 1e20
44
+ ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
45
+
46
+ if is_sagemaker_dp_enabled():
47
+ import smdistributed.dataparallel.torch.distributed as dist
48
+ else:
49
+ import torch.distributed as dist
50
+
51
+ _is_torch_generator_available = False
52
+ if version.parse(torch.__version__) >= version.parse("1.6"):
53
+ _is_torch_generator_available = True
54
+
55
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
56
+ token_dictionary = pickle.load(f)
57
+
58
+
59
class ExplicitEnum(Enum):
    """
    Enum whose failed value lookup reports the valid choices, instead of
    the terse default "is not a valid" message.
    """

    @classmethod
    def _missing_(cls, value):
        # Invoked by Enum's lookup machinery when `value` matches no member;
        # raising here replaces the default error with a listing of options.
        choices = list(cls._value2member_map_.keys())
        raise ValueError(
            f"{value!r} is not a valid {cls.__name__}, please select one of {choices}"
        )
69
+ )
70
+
71
+
72
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.

    Consumed by ``GeneformerPreCollator._get_padding_truncation_strategies``.
    """

    ONLY_FIRST = "only_first"  # truncate only the first sequence of a pair
    ONLY_SECOND = "only_second"  # truncate only the second sequence of a pair
    LONGEST_FIRST = "longest_first"  # iteratively trim whichever sequence is longer
    DO_NOT_TRUNCATE = "do_not_truncate"  # leave inputs untouched
82
+
83
+
84
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.

    Consumed by ``GeneformerPreCollator.pad`` / ``_pad``.
    """

    LONGEST = "longest"  # pad each batch to its longest sequence
    MAX_LENGTH = "max_length"  # pad to an explicit max_length
    DO_NOT_PAD = "do_not_pad"  # no padding; sequences may differ in length
93
+
94
+
95
class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"  # torch.Tensor
    TENSORFLOW = "tf"  # tf.constant
    NUMPY = "np"  # np.ndarray
    JAX = "jax"  # jax arrays
105
+
106
+
107
class GeneformerPreCollator(SpecialTokensMixin):
    # Minimal tokenizer-like object for single-cell rank-value encodings.
    # It implements just enough of the HF tokenizer interface (``pad``,
    # special-token ids, ``__len__``) for ``DataCollatorForLanguageModeling``
    # to mask and pad batches; most method bodies are copied from
    # ``transformers.tokenization_utils_base``.
    def __init__(self, *args, **kwargs) -> None:

        # Register "<mask>"/"<pad>" with SpecialTokensMixin so that
        # mask_token_id / pad_token_id resolve via convert_tokens_to_ids below.
        super().__init__(mask_token = "<mask>", pad_token = "<pad>")

        # token_dictionary: mapping used for token<->id lookups.
        # NOTE(review): required in practice — if the kwarg is omitted this is
        # None and later lookups fail; consider making it a positional arg.
        self.token_dictionary = kwargs.get("token_dictionary")
        # self.mask_token = "<mask>"
        # self.mask_token_id = self.token_dictionary.get("<mask>")
        # self.pad_token = "<pad>"
        # self.pad_token_id = self.token_dictionary.get("<pad>")
        self.padding_side = "right"
        # self.all_special_ids = [
        #     self.token_dictionary.get("<mask>"),
        #     self.token_dictionary.get("<pad>"),
        # ]
        self.model_input_names = ["input_ids"]

    def convert_ids_to_tokens(self,value):
        # NOTE(review): looks up `value` directly in token_dictionary; this
        # only inverts ids to tokens if the dictionary maps in that direction —
        # TODO confirm against how token_dictionary is built.
        return self.token_dictionary.get(value)

    def _get_padding_truncation_strategies(
        self,
        padding=False,
        truncation=False,
        max_length=None,
        pad_to_multiple_of=None,
        verbose=True,
        **kwargs,
    ):
        """
        Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
        and pad_to_max_length) and behaviors.
        """
        # NOTE(review): relies on self.deprecation_warnings and
        # self.model_max_length, neither of which is set in __init__ —
        # presumably provided by SpecialTokensMixin; verify.
        old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
        old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

        # Backward compatibility for previous behavior, maybe we should deprecate it:
        # If you only set max_length, it activates truncation for max_length
        if max_length is not None and padding is False and truncation is False:
            if verbose:
                if not self.deprecation_warnings.get(
                    "Truncation-not-explicitly-activated", False
                ):
                    logger.warning(
                        "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                        "please use `truncation=True` to explicitly truncate examples to max length. "
                        "Defaulting to 'longest_first' truncation strategy. "
                        "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                        "more precisely by providing a specific strategy to `truncation`."
                    )
                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
            truncation = "longest_first"

        # Get padding strategy
        if padding is False and old_pad_to_max_length:
            if verbose:
                warnings.warn(
                    "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                    "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                    "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                    "maximal input size of the model (e.g. 512 for Bert).",
                    FutureWarning,
                )
            if max_length is None:
                padding_strategy = PaddingStrategy.LONGEST
            else:
                padding_strategy = PaddingStrategy.MAX_LENGTH
        elif padding is not False:
            if padding is True:
                padding_strategy = (
                    PaddingStrategy.LONGEST
                )  # Default to pad to the longest sequence in the batch
            elif not isinstance(padding, PaddingStrategy):
                padding_strategy = PaddingStrategy(padding)
            elif isinstance(padding, PaddingStrategy):
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Get truncation strategy
        if truncation is False and old_truncation_strategy != "do_not_truncate":
            if verbose:
                warnings.warn(
                    "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                    "use `truncation=True` to truncate examples to a max length. You can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                    "maximal input size of the model (e.g. 512 for Bert). "
                    " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                    "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                    "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                    "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                    FutureWarning,
                )
            truncation_strategy = TruncationStrategy(old_truncation_strategy)
        elif truncation is not False:
            if truncation is True:
                truncation_strategy = (
                    TruncationStrategy.LONGEST_FIRST
                )  # Default to truncate the longest sequences in pairs of inputs
            elif not isinstance(truncation, TruncationStrategy):
                truncation_strategy = TruncationStrategy(truncation)
            elif isinstance(truncation, TruncationStrategy):
                truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get(
                            "Asking-to-pad-to-max_length", False
                        ):
                            logger.warning(
                                "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no padding."
                            )
                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                    padding_strategy = PaddingStrategy.DO_NOT_PAD
                else:
                    max_length = self.model_max_length

            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get(
                            "Asking-to-truncate-to-max_length", False
                        ):
                            logger.warning(
                                "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no truncation."
                            )
                        self.deprecation_warnings[
                            "Asking-to-truncate-to-max_length"
                        ] = True
                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
                else:
                    max_length = self.model_max_length

        # Test if we have a padding token
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
            not self.pad_token or self.pad_token_id < 0
        ):
            raise ValueError(
                "Asking to pad but the tokenizer does not have a padding token. "
                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
            )

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
            and padding_strategy != PaddingStrategy.DO_NOT_PAD
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                f"Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        return padding_strategy, truncation_strategy, max_length, kwargs

    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
        in the batch.

        Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
        ``self.pad_token_id`` and ``self.pad_token_type_id``)

        .. note::

            If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
            result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
            case of PyTorch tensors, you will lose the specific device of your tensors however.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
                well as in a PyTorch Dataloader collate function.

                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence if provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask (:obj:`bool`, `optional`):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.

                `What are attention masks? <../glossary.html#attention-mask>`__
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to print more information and warnings.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(
            encoded_inputs[0], (dict, BatchEncoding)
        ):
            encoded_inputs = {
                key: [example[key] for example in encoded_inputs]
                for key in encoded_inputs[0].keys()
            }

        # The model's main input name, usually `input_ids`, has be passed for padding
        if self.model_input_names[0] not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method"
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[self.model_input_names[0]]

        if not required_input:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch

        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
            # NOTE(review): if every element is empty, this while loop walks off
            # the end and raises IndexError (same as upstream HF code).
            index = 0
            while len(required_input[index]) == 0:
                index += 1
            if index < len(required_input):
                first_element = required_input[index][0]
        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
        if not isinstance(first_element, (int, list, tuple)):
            if is_tf_available() and _is_tensorflow(first_element):
                return_tensors = "tf" if return_tensors is None else return_tensors
            elif is_torch_available() and _is_torch(first_element):
                return_tensors = "pt" if return_tensors is None else return_tensors
            elif isinstance(first_element, np.ndarray):
                return_tensors = "np" if return_tensors is None else return_tensors
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    f"Should be one of a python, numpy, pytorch or tensorflow object."
                )

            for key, value in encoded_inputs.items():
                encoded_inputs[key] = to_py_obj(value)


        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
        )

        required_input = encoded_inputs[self.model_input_names[0]]
        if required_input and not isinstance(required_input[0], (list, tuple)):
            # Single (un-batched) example: pad it directly.
            encoded_inputs = self._pad(
                encoded_inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        if padding_strategy == PaddingStrategy.LONGEST:
            # Resolve LONGEST into a concrete MAX_LENGTH for this batch.
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        batch_outputs = {}
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
            outputs = self._pad(
                inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        return BatchEncoding(batch_outputs, tensor_type=return_tensors)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        required_input = encoded_inputs[self.model_input_names[0]]

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # Round max_length up to the next multiple, if requested.
        if (
            max_length is not None
            and pad_to_multiple_of is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = (
            padding_strategy != PaddingStrategy.DO_NOT_PAD
            and len(required_input) != max_length
        )

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(required_input) + [
                        0
                    ] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"]
                        + [self.pad_token_type_id] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = (
                        encoded_inputs["special_tokens_mask"] + [1] * difference
                    )
                encoded_inputs[self.model_input_names[0]] = (
                    required_input + [self.pad_token_id] * difference
                )
            elif self.padding_side == "left":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
                        required_input
                    )
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = [
                        self.pad_token_type_id
                    ] * difference + encoded_inputs["token_type_ids"]
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = [
                        1
                    ] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs[self.model_input_names[0]] = [
                    self.pad_token_id
                ] * difference + required_input
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
        elif return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * len(required_input)

        return encoded_inputs

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`):
                List of ids of the second sequence.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        # NOTE(review): assert-based validation is stripped under `python -O`;
        # an explicit raise would be safer.
        assert already_has_special_tokens and token_ids_1 is None, (
            "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
            "Please use a slow (full python) tokenizer to activate this argument."
            "Or set `return_special_tokens_mask=True` when calling the encoding method "
            "to get the special tokens mask in any tokenizer. "
        )

        all_special_ids = self.all_special_ids  # cache the property

        special_tokens_mask = [
            1 if token in all_special_ids else 0 for token in token_ids_0
        ]

        return special_tokens_mask

    def convert_tokens_to_ids(
        self, tokens: Union[str, List[str]]
    ) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
        vocabulary.
        Args:
            tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
        Returns:
            :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        # Single-token lookup; returns None for unknown tokens (dict.get).
        if token is None:
            return None

        return self.token_dictionary.get(token)

    def __len__(self):
        # Vocabulary size as seen by the data collator / model config.
        return len(self.token_dictionary)
594
+ return len(self.token_dictionary)
595
+
596
+
597
class GeneformerPretrainer(Trainer):
    """
    Hugging Face ``Trainer`` subclass for Geneformer masked-language-model
    pretraining on single-cell transcriptomics data.

    Differences from the stock ``Trainer``:

    * If no ``data_collator`` kwarg is supplied, one is built from a
      ``GeneformerPreCollator`` (constructed with the required
      ``token_dictionary`` kwarg) wrapped in
      ``DataCollatorForLanguageModeling`` with 15% masking.
    * Per-example lengths are loaded from a pickled list given by the
      required ``example_lengths_file`` kwarg, so the length-grouped sampler
      does not have to scan the (very large) dataset.

    Extra kwargs (consumed here, not passed to ``Trainer``):
        token_dictionary: mapping used by the pre-collator for token ids.
        example_lengths_file: path to a pickled list of example lengths,
            pre-obtained with ``[dataset[i]["length"] for i in range(len(dataset))]``.

    Raises:
        Exception: if ``example_lengths_file`` is missing or falsy.
    """

    def __init__(self, *args, **kwargs):
        data_collator = kwargs.get("data_collator", None)
        # Consumed here; the base Trainer does not accept this kwarg.
        token_dictionary = kwargs.pop("token_dictionary")

        if data_collator is None:
            precollator = GeneformerPreCollator(token_dictionary=token_dictionary)

            # Standard MLM collator: mask 15% of input tokens.
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=precollator, mlm=True, mlm_probability=0.15
            )
            kwargs["data_collator"] = data_collator

        # Load previously saved length vector for dataset to speed up
        # LengthGroupedSampler.
        # Pop with a default of None so that a missing kwarg falls through to
        # the informative Exception below instead of raising a bare KeyError.
        example_lengths_file = kwargs.pop("example_lengths_file", None)
        if example_lengths_file:
            with open(example_lengths_file, "rb") as f:
                self.example_lengths = pickle.load(f)
        else:
            raise Exception(
                "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
            )
        super().__init__(*args, **kwargs)

    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        """Return the training sampler, substituting the preloaded
        ``self.example_lengths`` for the dataset's length column when
        length-grouped batching is enabled."""
        if not isinstance(self.train_dataset, collections.abc.Sized):
            return None

        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            generator = torch.Generator()
            # Seed from torch's global RNG so shuffling follows torch seeding.
            generator.manual_seed(
                int(torch.empty((), dtype=torch.int64).random_().item())
            )

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, Dataset):
                # Key modification: use the preloaded length vector rather than
                # querying dataset["length"], which can hang on huge datasets.
                lengths = self.example_lengths
            else:
                lengths = None
            model_input_name = (
                self.tokenizer.model_input_names[0]
                if self.tokenizer is not None
                else None
            )
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return CustomDistributedLengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=self.args.seed,
                )

        else:
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode
                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )
693
+
694
+
695
class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
    r"""
    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.

    Custom behavior: accepts a precomputed ``lengths`` list so it never has to
    iterate the dataset to measure examples.
    """
    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        # NOTE(review): deliberately does NOT call super().__init__(); all
        # state the parent would set is re-initialized here.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0  # updated externally via set_epoch for per-epoch shuffling
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed
        self.model_input_name = (
            model_input_name if model_input_name is not None else "input_ids"
        )

        if lengths is None:
            # Fallback: infer lengths by scanning the dataset (slow for large data).
            print("Lengths is none - calculating lengths.")
            if (
                not (
                    isinstance(dataset[0], dict)
                    or isinstance(dataset[0], BatchEncoding)
                )
                or self.model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{self.model_input_name}' key."
                )
            lengths = [len(feature[self.model_input_name]) for feature in dataset]
        self.lengths = lengths

    def __iter__(self) -> Iterator:
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        # get_length_grouped_indices is defined at module level below.
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        # Each rank takes an interleaved slice so all ranks get num_samples items.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
779
+
780
+
781
def get_length_grouped_indices(
    lengths, batch_size, mega_batch_mult=None, generator=None
):
    """
    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
    similar lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    if mega_batch_mult is None:
        # Default: up to 1000 batches per mega-batch, fewer for small datasets.
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
        if mega_batch_mult == 0:
            # Just in case, for tiny datasets
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    permutation = torch.randperm(len(lengths), generator=generator)
    chunk = mega_batch_mult * batch_size
    mega_batches = []
    for start in range(0, len(lengths), chunk):
        block = permutation[start : start + chunk].tolist()
        # Sort each mega-batch by descending length (stable sort).
        block.sort(key=lambda idx: lengths[idx], reverse=True)
        mega_batches.append(block)

    # Move the globally longest element into the very first batch, so that an
    # OOM surfaces as early in training as possible. Each mega-batch is sorted
    # by descending length, so its longest element sits at position 0.
    leading_lengths = [lengths[block[0]] for block in mega_batches]
    longest_at = torch.argmax(torch.tensor(leading_lengths)).item()
    mega_batches[0][0], mega_batches[longest_at][0] = (
        mega_batches[longest_at][0],
        mega_batches[0][0],
    )

    return [idx for block in mega_batches for idx in block]
826
+
827
+
828
+ # from typing import Any, Tuple, Optional
829
+
830
+ # class CustomDataCollatorForMLM(DataCollatorForLanguageModeling):
831
+ # # def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
832
+
833
+ # # import torch
834
+
835
+ # # labels = inputs.clone()
836
+ # # # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
837
+ # # probability_matrix = torch.full(labels.shape, self.mlm_probability)
838
+ # # if special_tokens_mask is None:
839
+ # # special_tokens_mask = [
840
+ # # self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
841
+ # # ]
842
+ # # special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
843
+ # # else:
844
+ # # special_tokens_mask = special_tokens_mask.bool()
845
+ # # probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
846
+ # # masked_indices = torch.bernoulli(probability_matrix).bool()
847
+ # # labels[~masked_indices] = -100 # We only compute loss on masked tokens
848
+
849
+ # # # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
850
+ # # indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
851
+ # # inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
852
+
853
+ # # # 15% of the time, we replace masked input tokens with random word
854
+ # # indices_random = torch.bernoulli(torch.full(labels.shape, 0.75)).bool() & masked_indices & ~indices_replaced
855
+ # # random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
856
+ # # inputs[indices_random] = random_words[indices_random]
857
+
858
+ # # # The rest of the time (5% of the time) we keep the masked input tokens unchanged
859
+ # # return inputs, labels
860
+
861
+ # def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
862
+ # import torch
863
+
864
+ # labels = inputs.clone()
865
+ # probability_matrix = torch.full(labels.shape, self.mlm_probability)
866
+ # if special_tokens_mask is None:
867
+ # special_tokens_mask = [
868
+ # self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
869
+ # ]
870
+ # special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
871
+ # else:
872
+ # special_tokens_mask = special_tokens_mask.bool()
873
+
874
+ # probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
875
+ # masked_indices = torch.bernoulli(probability_matrix).bool()
876
+ # labels[~masked_indices] = -100 # We only compute loss on masked tokens
877
+
878
+ # # 100% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
879
+ # inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
880
+
881
+ # return inputs, labels
882
+
883
+ # class CustomGeneformerPretrainer(Trainer):
884
+ # def __init__(self, *args, **kwargs):
885
+ # data_collator = kwargs.get("data_collator",None)
886
+ # token_dictionary = kwargs.pop("token_dictionary")
887
+
888
+ # if data_collator is None:
889
+ # precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
890
+
891
+ # # # Data Collator Functions
892
+ # data_collator = CustomDataCollatorForMLM(
893
+ # tokenizer=precollator, mlm=True, mlm_probability=0.15
894
+ # )
895
+ # kwargs["data_collator"] = data_collator
896
+
897
+ # # load previously saved length vector for dataset to speed up LengthGroupedSampler
898
+ # # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
899
+ # example_lengths_file = kwargs.pop("example_lengths_file")
900
+ # if example_lengths_file:
901
+ # with open(example_lengths_file, "rb") as f:
902
+ # self.example_lengths = pickle.load(f)
903
+ # else:
904
+ # raise Exception(
905
+ # "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
906
+ # )
907
+ # super().__init__(*args, **kwargs)
908
+ # # self.exp_logits_dir = exp_logits_dir
909
+ # # self.min_exp_logits = float('inf')
910
+ # # self.max_exp_logits = float('-inf')
911
+
912
+ # # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
913
+ # def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
914
+ # if not isinstance(self.train_dataset, collections.abc.Sized):
915
+ # return None
916
+
917
+ # generator = None
918
+ # if self.args.world_size <= 1 and _is_torch_generator_available:
919
+ # generator = torch.Generator()
920
+ # generator.manual_seed(
921
+ # int(torch.empty((), dtype=torch.int64).random_().item())
922
+ # )
923
+
924
+ # # Build the sampler.
925
+ # if self.args.group_by_length:
926
+ # if is_datasets_available() and isinstance(self.train_dataset, Dataset):
927
+ # lengths = self.example_lengths
928
+ # else:
929
+ # lengths = None
930
+ # model_input_name = (
931
+ # self.tokenizer.model_input_names[0]
932
+ # if self.tokenizer is not None
933
+ # else None
934
+ # )
935
+ # if self.args.world_size <= 1:
936
+ # return LengthGroupedSampler(
937
+ # dataset=self.train_dataset,
938
+ # batch_size=self.args.train_batch_size,
939
+ # lengths=lengths,
940
+ # model_input_name=model_input_name,
941
+ # generator=generator,
942
+ # )
943
+ # else:
944
+ # return CustomDistributedLengthGroupedSampler(
945
+ # dataset=self.train_dataset,
946
+ # batch_size=self.args.train_batch_size,
947
+ # num_replicas=self.args.world_size,
948
+ # rank=self.args.process_index,
949
+ # lengths=lengths,
950
+ # model_input_name=model_input_name,
951
+ # seed=self.args.seed,
952
+ # )
953
+
954
+ # else:
955
+ # if self.args.world_size <= 1:
956
+ # if _is_torch_generator_available:
957
+ # return RandomSampler(self.train_dataset, generator=generator)
958
+ # return RandomSampler(self.train_dataset)
959
+ # elif (
960
+ # self.args.parallel_mode
961
+ # in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
962
+ # and not self.args.dataloader_drop_last
963
+ # ):
964
+ # # Use a loop for TPUs when drop_last is False to have all batches have the same size.
965
+ # return DistributedSamplerWithLoop(
966
+ # self.train_dataset,
967
+ # batch_size=self.args.per_device_train_batch_size,
968
+ # num_replicas=self.args.world_size,
969
+ # rank=self.args.process_index,
970
+ # seed=self.args.seed,
971
+ # )
972
+ # else:
973
+ # return DistributedSampler(
974
+ # self.train_dataset,
975
+ # num_replicas=self.args.world_size,
976
+ # rank=self.args.process_index,
977
+ # seed=self.args.seed,
978
+ # )
geneformer/token_dictionary.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1
3
+ size 788424
geneformer/tokenizer.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer tokenizer.
3
+
4
+ **Input data:**
5
+
6
+ | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
7
+ | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
+ | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
+
10
+ | *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
11
+ | *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
12
+
13
+ **Usage:**
14
+
15
+ .. code-block :: python
16
+
17
+ >>> from geneformer import TranscriptomeTokenizer
18
+ >>> tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ"}, nproc=4)
19
+ >>> tk.tokenize_data("data_directory", "output_directory", "output_prefix")
20
+
21
+ **Description:**
22
+
23
+ | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
+
25
+ | The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, except that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
26
+
27
+ | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
+
29
+ | No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively, the following custom attribute dictionary should be provided: {"cell_type": "cell_type", "organ_major": "organ"}.
30
+
31
+ | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
+
33
+ | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
34
+
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import logging
40
+ import pickle
41
+ import warnings
42
+ from pathlib import Path
43
+ from typing import Literal
44
+
45
+ import anndata as ad
46
+ import loompy as lp
47
+ import numpy as np
48
+ import scipy.sparse as sp
49
+ from datasets import Dataset
50
+
51
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
52
+ logger = logging.getLogger(__name__)
53
+
54
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
55
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
56
+
57
+
58
def rank_genes(gene_vector, gene_tokens):
    """
    Return gene tokens reordered by descending median-scaled expression value.
    """
    # Negating the values makes argsort yield a descending-value ordering.
    descending_order = np.argsort(-gene_vector)
    return gene_tokens[descending_order]
65
+
66
+
67
def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert a normalized gene expression vector to its tokenized rank value
    encoding: tokens of the detected (non-zero) genes, ordered by descending
    median-scaled expression.
    """
    # Drop undetected genes before ranking.
    detected = np.nonzero(gene_vector)[0]
    expressed_values = gene_vector[detected]
    expressed_tokens = gene_tokens[detected]
    # Inlined rank_genes: argsort on negated values gives a descending order.
    return expressed_tokens[np.argsort(-expressed_values)]
76
+
77
+
78
class TranscriptomeTokenizer:
    """
    Tokenize raw-count scRNA-seq data (.loom or .h5ad) into rank value encodings.

    Each cell is normalized by its total counts (scaled by ``target_sum``) and
    by per-gene non-zero median expression across Genecorpus-30M, then encoded
    as a sequence of gene tokens ordered by descending normalized expression
    and truncated to the 2,048-token model input size.
    """

    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        chunk_size=512,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        **Parameters:**

        custom_attr_name_dict : None, dict
            | Dictionary of custom attributes to be added to the dataset.
            | Keys are the names of the attributes in the loom file.
            | Values are the names of the attributes in the dataset.
        nproc : int
            | Number of processes to use for dataset mapping.
        chunk_size : int
            | Chunk size for anndata tokenizer.
        gene_median_file : Path
            | Path to pickle file containing dictionary of non-zero median
            | gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl IDs:token).
        """
        # dictionary of custom attributes {input attr name: output dataset column name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # chunk size for anndata tokenizer
        self.chunk_size = chunk_size

        # load dictionary of gene normalization factors
        # (non-zero median value of expression across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting rows/cols for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    def tokenize_data(
        self,
        data_directory: Path | str,
        output_directory: Path | str,
        output_prefix: str,
        file_format: Literal["loom", "h5ad"] = "loom",
        use_generator: bool = False,
    ):
        """
        Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.

        **Parameters:**

        data_directory : Path
            | Path to directory containing loom files or anndata files
        output_directory : Path
            | Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            | Prefix for output .dataset
        file_format : str
            | Format of input files. Can be "loom" or "h5ad".
        use_generator : bool
            | Whether to use generator or dict for tokenization.
        """
        tokenized_cells, cell_metadata = self.tokenize_files(
            Path(data_directory), file_format
        )
        tokenized_dataset = self.create_dataset(
            tokenized_cells, cell_metadata, use_generator=use_generator
        )

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)

    def tokenize_files(
        self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
    ):
        """
        Tokenize every .loom or .h5ad file in ``data_directory``.

        Returns (tokenized_cells, cell_metadata); ``cell_metadata`` is None
        when no custom attributes were requested.

        Raises FileNotFoundError when no matching files are found.
        """
        tokenized_cells = []
        if self.custom_attr_name_dict is not None:
            cell_attr = list(self.custom_attr_name_dict.keys())
            cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.values()
            }
        else:
            cell_metadata = None

        # loops through directory to tokenize .loom or .h5ad files
        file_found = False
        tokenize_file_fn = (
            self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
        )
        for file_path in data_directory.glob(f"*.{file_format}"):
            file_found = True
            print(f"Tokenizing {file_path}")
            file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                for k in cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
                        k
                    ]

        if not file_found:
            msg = f"No .{file_format} files found in directory {data_directory}."
            logger.error(msg)
            # BUGFIX: the original bare `raise` had no active exception and
            # crashed with "RuntimeError: No active exception to re-raise";
            # raise a meaningful error instead.
            raise FileNotFoundError(msg)
        return tokenized_cells, cell_metadata

    def tokenize_anndata(self, adata_file_path, target_sum=10_000):
        """
        Tokenize a single .h5ad file.

        Returns (tokenized_cells, file_cell_metadata); ``file_cell_metadata``
        is None when no custom attributes were requested.
        """
        # backed mode avoids loading the full matrix into memory
        adata = ad.read(adata_file_path, backed="r")

        # BUGFIX: bind file_cell_metadata up front so a file with zero cells
        # passing filters no longer hits UnboundLocalError at return.
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }
        else:
            file_cell_metadata = None

        # locations of detected protein-coding / miRNA genes and their
        # median normalization factors
        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
        )[0]
        norm_factor_vector = np.array(
            [
                self.gene_median_dict[i]
                for i in adata.var["ensembl_id"][coding_miRNA_loc]
            ]
        )
        coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[i] for i in coding_miRNA_ids]
        )

        # cells passing user-defined filters (all cells if no "filter_pass" column)
        if "filter_pass" in adata.obs:
            filter_pass_loc = np.where([i == 1 for i in adata.obs["filter_pass"]])[0]
        else:
            print(
                f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
            )
            filter_pass_loc = np.arange(adata.shape[0])

        tokenized_cells = []

        for chunk_start in range(0, len(filter_pass_loc), self.chunk_size):
            idx = filter_pass_loc[chunk_start : chunk_start + self.chunk_size]

            # normalize by total counts per cell, scale by target_sum to
            # allocate bits to precision, and divide by per-gene median factors
            n_counts = adata[idx].obs["n_counts"].values[:, None]
            X_view = adata[idx, coding_miRNA_loc].X
            X_norm = X_view / n_counts * target_sum / norm_factor_vector
            X_norm = sp.csr_matrix(X_norm)

            # rank the nonzero genes of each cell by normalized expression
            # (a CSR row's .data/.indices hold exactly the nonzero entries)
            tokenized_cells += [
                rank_genes(X_norm[row].data, coding_miRNA_tokens[X_norm[row].indices])
                for row in range(X_norm.shape[0])
            ]

            # add custom attributes for this chunk to dict
            if file_cell_metadata is not None:
                for k in file_cell_metadata.keys():
                    file_cell_metadata[k] += adata[idx].obs[k].tolist()

        return tokenized_cells, file_cell_metadata

    def tokenize_loom(self, loom_file_path, target_sum=10_000):
        """
        Tokenize a single .loom file.

        Returns (tokenized_cells, file_cell_metadata); ``file_cell_metadata``
        is None when no custom attributes were requested.
        """
        # BUGFIX: bind file_cell_metadata up front so a file with zero cells
        # passing filters no longer hits UnboundLocalError at return.
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }
        else:
            file_cell_metadata = None

        with lp.connect(str(loom_file_path)) as data:
            # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # define coordinates of cells passing filters for inclusion (e.g. QC)
            try:
                # loompy raises AttributeError for a missing column attribute
                filter_pass = data.ca["filter_pass"]
            except AttributeError:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])
            else:
                filter_pass_loc = np.where([i == 1 for i in filter_pass])[0]

            # scan through .loom file in column batches and tokenize cells
            tokenized_cells = []
            for _ix, _selection, view in data.scan(
                items=filter_pass_loc, axis=1, batch_size=self.chunk_size
            ):
                # select subview with protein-coding and miRNA genes
                subview = view.view[coding_miRNA_loc, :]

                # normalize by total counts per cell, multiply by target_sum to
                # allocate bits to precision, and divide by per-gene factors
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * target_sum
                    / norm_factor_vector[:, None]
                )
                # tokenize subview gene vectors (one per cell/column)
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, col], coding_miRNA_tokens)
                    for col in range(subview_norm_array.shape[1])
                ]

                # add custom attributes for subview to dict
                if file_cell_metadata is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()

        return tokenized_cells, file_cell_metadata

    def create_dataset(
        self,
        tokenized_cells,
        cell_metadata,
        use_generator=False,
        keep_uncropped_input_ids=False,
    ):
        """
        Assemble tokenized cells (plus optional metadata) into a Dataset,
        truncating each cell's input_ids to the 2,048-token model input size.

        With ``keep_uncropped_input_ids``, the original full-length encoding is
        retained in "input_ids_uncropped"/"length_uncropped" columns.
        """
        print("Creating dataset.")
        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # create dataset (the generator path avoids materializing a second
        # in-memory copy of the full dict)
        if use_generator:

            def dict_generator():
                for i in range(len(tokenized_cells)):
                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}

            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
        else:
            output_dataset = Dataset.from_dict(dataset_dict)

        def format_cell_features(example):
            # Store original uncropped input_ids in separate feature
            if keep_uncropped_input_ids:
                example["input_ids_uncropped"] = example["input_ids"]
                example["length_uncropped"] = len(example["input_ids"])

            # Truncate/Crop input_ids to size 2,048
            example["input_ids"] = example["input_ids"][:2048]
            example["length"] = len(example["input_ids"])

            return example

        return output_dataset.map(format_cell_features, num_proc=self.nproc)