aletlvl committed on
Commit
016773e
·
verified ·
1 Parent(s): f7999b5

Upload tokenization_nicheformer.py

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +277 -346
tokenization_nicheformer.py CHANGED
@@ -1,399 +1,330 @@
1
- from transformers import PreTrainedTokenizer
2
- import numpy as np
3
  from typing import List, Dict, Optional, Union, Tuple
 
 
 
 
 
 
 
4
  import os
5
  import json
6
- from anndata import AnnData
7
- import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class NicheformerTokenizer(PreTrainedTokenizer):
10
- """
11
- Tokenizer for Nicheformer models.
12
 
13
- This tokenizer converts gene expression data from AnnData objects into token IDs
14
- for the Nicheformer model. It also handles special tokens for modality, species,
15
- and assay information extracted from the observation columns.
16
- """
17
-
18
- vocab_files_names = {"vocab_file": "vocab.json"}
19
  model_input_names = ["input_ids", "attention_mask"]
 
 
 
 
 
20
 
21
  def __init__(
22
  self,
23
  vocab_file=None,
24
- max_seq_len=4096,
25
- aux_tokens=30,
 
 
26
  **kwargs
27
  ):
28
- """
29
- Initialize the tokenizer.
30
-
31
- Args:
32
- vocab_file: Path to the vocabulary file
33
- max_seq_len: Maximum sequence length
34
- aux_tokens: Number of auxiliary tokens reserved
35
- """
36
- # Initialize vocabulary
37
- self.vocab = {}
38
- self.ids_to_tokens = {}
39
-
40
- # Load vocabulary if provided
41
- if vocab_file is not None and os.path.isfile(vocab_file):
42
- with open(vocab_file, 'r', encoding='utf-8') as f:
43
- self.vocab = json.load(f)
44
- self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
45
-
46
- # Initialize the parent class
47
- super().__init__(
48
- pad_token="<pad>",
49
- eos_token="<eos>",
50
- unk_token="",
51
- **kwargs
52
- )
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- self.max_seq_len = max_seq_len
55
  self.aux_tokens = aux_tokens
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Define token constants to match Nicheformer
58
- self._pad_token_id = 0
 
 
59
 
60
- # Define special token mappings
61
- self.modality_dict = {
62
- 'dissociated': 3,
63
- 'spatial': 4,
64
- }
 
 
 
 
 
 
65
 
66
- self.specie_dict = {
67
- 'human': 5,
68
- 'Homo sapiens': 5,
69
- 'Mus musculus': 6,
70
- 'mouse': 6,
71
- }
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- self.technology_dict = {
74
- "merfish": 7,
75
- "MERFISH": 7,
76
- "cosmx": 8,
77
- "visium": 9,
78
- "10x 5' v2": 10,
79
- "10x 3' v3": 11,
80
- "10x 3' v2": 12,
81
- "10x 5' v1": 13,
82
- "10x 3' v1": 14,
83
- "10x 3' transcription profiling": 15,
84
- "10x transcription profiling": 15,
85
- "10x 5' transcription profiling": 16,
86
- "CITE-seq": 17,
87
- "Smart-seq v4": 18,
88
- }
89
-
90
- def get_vocab(self) -> Dict[str, int]:
91
- """Return the vocabulary as a dictionary of token to index."""
92
- if not self.vocab:
93
- # If vocab is empty, create a minimal vocab with special tokens
94
- vocab = {}
95
- # Add special tokens
96
- vocab["<pad>"] = 0
97
- vocab["<eos>"] = 1
98
- vocab[""] = 2
99
- # Add modality tokens
100
- for token, idx in self.modality_dict.items():
101
- vocab[token] = idx
102
- # Add species tokens
103
- for token, idx in self.specie_dict.items():
104
- vocab[token] = idx
105
- # Add technology tokens
106
- for token, idx in self.technology_dict.items():
107
- vocab[token] = idx
108
- return vocab
109
- return self.vocab
110
-
111
- def _tokenize(self, text):
112
- """
113
- Not used for gene expression data, but required by the interface.
114
- """
115
- return [text]
116
-
117
- def _convert_token_to_id(self, token):
118
- """Convert a token to an ID using the vocab."""
119
- return self.vocab.get(token, self.vocab.get(self.unk_token))
120
-
121
- def _convert_id_to_token(self, index):
122
- """Convert an ID to a token using the vocab."""
123
- return self.ids_to_tokens.get(index, self.unk_token)
124
-
125
- def convert_tokens_to_string(self, tokens):
126
- """
127
- Not used for gene expression data, but required by the interface.
128
- """
129
- return " ".join(tokens)
130
-
131
- def save_vocabulary(self, save_directory, filename_prefix=None):
132
  """Save the vocabulary to a file."""
133
  vocab_file = os.path.join(
134
  save_directory,
135
- (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
136
  )
137
 
138
  with open(vocab_file, "w", encoding="utf-8") as f:
139
- json.dump(self.vocab, f, ensure_ascii=False)
140
 
141
  return (vocab_file,)
142
-
143
- def _sf_normalize(self, X):
144
- """Normalize the input matrix to a scale of 10000."""
145
- X = X.copy()
146
- counts = np.array(X.sum(axis=1))
147
- # avoid zero division error
148
- counts += counts == 0.
149
- # normalize to 10000 counts
150
- scaling_factor = 10000. / counts
151
-
152
- from scipy.sparse import issparse
153
- if issparse(X):
154
- from sklearn.utils import sparsefuncs
155
- sparsefuncs.inplace_row_scale(X, scaling_factor)
156
- else:
157
- np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)
158
-
159
- return X
160
-
161
- def _sub_tokenize_data(self, x):
162
- """Tokenize the input gene vector"""
163
- from scipy.sparse import issparse
164
-
165
- if issparse(x):
166
- x = x.toarray()
167
-
168
- n_cells, n_genes = x.shape
169
- scores_final = np.zeros((n_cells, self.max_seq_len), dtype=np.int32)
170
-
171
- for i, cell in enumerate(x):
172
- nonzero_mask = np.nonzero(cell)[0]
173
- sorted_indices = nonzero_mask[np.argsort(-cell[nonzero_mask])][:self.max_seq_len]
174
- sorted_indices = sorted_indices + self.aux_tokens # reserve tokens for padding etc
175
-
176
- scores = np.zeros(self.max_seq_len, dtype=np.int32)
177
- scores[:len(sorted_indices)] = sorted_indices.astype(np.int32)
178
-
179
- scores_final[i, :] = scores
180
-
181
- return scores_final
182
-
183
- def tokenize_anndata(self, adata, median_counts_per_gene=None, subset_obs=None):
184
- """
185
- Tokenize gene expression data from an AnnData object.
186
 
187
  Args:
188
- adata: AnnData object containing gene expression data
189
- median_counts_per_gene: Median counts per gene for normalization
190
- subset_obs: Indices or boolean mask to subset observations
191
 
192
  Returns:
193
- Dictionary with tokenized data
194
  """
195
- # Subset data if requested
196
- if subset_obs is not None:
197
- adata = adata[subset_obs].copy()
198
- else:
199
- adata = adata.copy()
200
-
201
- # Extract expression matrix
202
- X = adata.X
203
-
204
- # Normalize data
205
- X = np.nan_to_num(X) if not isinstance(X, np.ndarray) or not np.issubdtype(X.dtype, np.integer) else X
206
- X = self._sf_normalize(X)
207
-
208
- if median_counts_per_gene is not None:
209
- median_counts_per_gene = median_counts_per_gene.copy()
210
- median_counts_per_gene += median_counts_per_gene == 0
211
 
212
- if isinstance(X, np.ndarray):
213
- X = X / median_counts_per_gene.reshape((1, -1))
214
- else:
215
- # For sparse matrices, we need to handle this differently
216
- from scipy.sparse import issparse
217
- if issparse(X):
218
- X = X.toarray() / median_counts_per_gene.reshape((1, -1))
219
-
220
- # Tokenize
221
- tokens = self._sub_tokenize_data(X)
222
-
223
- # Create attention mask (1 for real tokens, 0 for padding)
224
- attention_mask = (tokens != self._pad_token_id).astype(np.int32)
225
-
226
- # Extract metadata from obs
227
- result = {
228
- "input_ids": tokens,
229
- "attention_mask": attention_mask
230
- }
231
 
232
- # Extract modality, specie, and assay from obs if available
233
- if 'modality' in adata.obs.columns:
234
- if adata.obs['modality'].dtype == 'object':
235
- # Convert string values to token IDs
236
- modality_ids = np.array([self.modality_dict.get(m, 0) for m in adata.obs['modality']])
237
- else:
238
- # Assume already tokenized
239
- modality_ids = adata.obs['modality'].values
240
- result["modality"] = modality_ids
241
-
242
- if 'specie' in adata.obs.columns:
243
- if adata.obs['specie'].dtype == 'object':
244
- specie_ids = np.array([self.specie_dict.get(s, 0) for s in adata.obs['specie']])
245
- else:
246
- specie_ids = adata.obs['specie'].values
247
- result["specie"] = specie_ids
248
-
249
- if 'assay' in adata.obs.columns:
250
- if adata.obs['assay'].dtype == 'object':
251
- assay_ids = np.array([self.technology_dict.get(a, 0) for a in adata.obs['assay']])
252
- else:
253
- assay_ids = adata.obs['assay'].values
254
- result["assay"] = assay_ids
255
 
256
- return result
257
 
258
- def batch_encode_plus(
259
  self,
260
- adata=None,
261
- expression_matrix=None,
262
- median_counts_per_gene=None,
263
- modality=None,
264
- specie=None,
265
- assay=None,
266
- subset_obs=None,
267
- return_tensors=None,
268
  **kwargs
269
- ):
270
- """
271
- Encode a batch of gene expression data.
272
 
273
  Args:
274
- adata: AnnData object containing gene expression data
275
- expression_matrix: Matrix of gene expression values (cells x genes)
276
- median_counts_per_gene: Median counts per gene for normalization
277
- modality: List or array of modality values
278
- specie: List or array of species values
279
- assay: List or array of assay/technology values
280
- subset_obs: Indices or boolean mask to subset observations in adata
281
- return_tensors: Format of the returned tensors ("pt" for PyTorch, "tf" for TensorFlow, None for numpy)
282
 
283
  Returns:
284
- Dictionary with encoded data
285
  """
286
  if adata is not None:
287
- # Use AnnData object
288
- result = self.tokenize_anndata(adata, median_counts_per_gene, subset_obs)
289
-
290
- # Override metadata if explicitly provided
291
- if modality is not None:
292
- if isinstance(modality[0], str):
293
- modality_ids = np.array([self.modality_dict.get(m, 0) for m in modality])
294
- else:
295
- modality_ids = np.array(modality)
296
- result["modality"] = modality_ids
297
 
298
- if specie is not None:
299
- if isinstance(specie[0], str):
300
- specie_ids = np.array([self.specie_dict.get(s, 0) for s in specie])
301
- else:
302
- specie_ids = np.array(specie)
303
- result["specie"] = specie_ids
 
304
 
305
- if assay is not None:
306
- if isinstance(assay[0], str):
307
- assay_ids = np.array([self.technology_dict.get(a, 0) for a in assay])
308
- else:
309
- assay_ids = np.array(assay)
310
- result["assay"] = assay_ids
311
-
312
- elif expression_matrix is not None:
313
- # Use raw expression matrix
314
- from scipy.sparse import issparse
315
-
316
- # Convert to numpy array if sparse
317
- if issparse(expression_matrix):
318
- expression_matrix = expression_matrix.toarray()
319
-
320
- # Normalize data
321
- expression_matrix = np.nan_to_num(expression_matrix)
322
- expression_matrix = self._sf_normalize(expression_matrix)
323
-
324
- if median_counts_per_gene is not None:
325
- median_counts_per_gene = median_counts_per_gene.copy()
326
- median_counts_per_gene += median_counts_per_gene == 0
327
- expression_matrix = expression_matrix / median_counts_per_gene.reshape((1, -1))
328
 
329
- # Tokenize
330
- tokens = self._sub_tokenize_data(expression_matrix)
 
 
 
 
 
 
 
 
 
331
 
332
- # Create attention mask (1 for real tokens, 0 for padding)
333
- attention_mask = (tokens != self._pad_token_id).astype(np.int32)
 
334
 
335
- # Add metadata tokens if provided
336
- result = {
337
- "input_ids": tokens,
338
- "attention_mask": attention_mask
339
- }
340
 
341
- if modality is not None:
342
- if isinstance(modality[0], str):
343
- modality_ids = np.array([self.modality_dict.get(m, 0) for m in modality])
344
- else:
345
- modality_ids = np.array(modality)
346
- result["modality"] = modality_ids
347
-
348
- if specie is not None:
349
- if isinstance(specie[0], str):
350
- specie_ids = np.array([self.specie_dict.get(s, 0) for s in specie])
351
- else:
352
- specie_ids = np.array(specie)
353
- result["specie"] = specie_ids
354
-
355
- if assay is not None:
356
- if isinstance(assay[0], str):
357
- assay_ids = np.array([self.technology_dict.get(a, 0) for a in assay])
358
- else:
359
- assay_ids = np.array(assay)
360
- result["assay"] = assay_ids
361
 
362
- else:
363
- raise ValueError("Either adata or expression_matrix must be provided")
364
-
365
- # Convert to tensors if requested
366
- if return_tensors == "pt":
367
- result = {k: torch.tensor(v) for k, v in result.items()}
368
- # Otherwise keep as numpy arrays
369
 
370
- return result
 
 
 
371
 
372
- def __call__(
373
- self,
374
- adata=None,
375
- expression_matrix=None,
376
- median_counts_per_gene=None,
377
- modality=None,
378
- specie=None,
379
- assay=None,
380
- return_tensors=None,
381
- subset_obs=None,
382
- **kwargs
383
- ):
384
- """
385
- Encode gene expression data.
386
-
387
- This is a convenience wrapper around batch_encode_plus.
388
- """
389
- return self.batch_encode_plus(
390
- adata=adata,
391
- expression_matrix=expression_matrix,
392
- median_counts_per_gene=median_counts_per_gene,
393
- modality=modality,
394
- specie=specie,
395
- assay=assay,
396
- return_tensors=return_tensors,
397
- subset_obs=subset_obs,
398
- **kwargs
399
- )
 
 
 
1
  from typing import List, Dict, Optional, Union, Tuple
2
+ import numpy as np
3
+ from transformers import PreTrainedTokenizer
4
+ from dataclasses import dataclass
5
+ import torch
6
+ import anndata as ad
7
+ from scipy.sparse import issparse
8
+ import numba
9
  import os
10
  import json
11
+
12
# Reserved special-token IDs — these must stay in lockstep with the
# original Nicheformer implementation, which trained with exactly
# this numbering.
PAD_TOKEN = 0
MASK_TOKEN = 1
CLS_TOKEN = 2

# Metadata-token lookup tables. IDs (and entry order, which drives
# vocabulary insertion order downstream) are fixed by the original model.
MODALITY_DICT = {
    'dissociated': 3,
    'spatial': 4,
}

# Several species appear under two spellings; both map to one ID.
SPECIES_DICT = {
    'human': 5,
    'Homo sapiens': 5,
    'Mus musculus': 6,
    'mouse': 6,
}

# Assay/technology names as they appear in upstream metadata; aliases
# share an ID (e.g. "merfish"/"MERFISH").
TECHNOLOGY_DICT = {
    "merfish": 7,
    "MERFISH": 7,
    "cosmx": 8,
    "NanoString digital spatial profiling": 8,
    "Xenium": 9,
    "10x 5' v2": 10,
    "10x 3' v3": 11,
    "10x 3' v2": 12,
    "10x 5' v1": 13,
    "10x 3' v1": 14,
    "10x 3' transcription profiling": 15,
    "10x transcription profiling": 15,
    "10x 5' transcription profiling": 16,
    "CITE-seq": 17,
    "Smart-seq v4": 18,
}
47
+
48
def sf_normalize(X: np.ndarray) -> np.ndarray:
    """Size-factor normalize each cell (row) to a total of 10,000 counts.

    Args:
        X: Expression matrix (cells x genes); dense ``ndarray`` or a scipy
            sparse matrix.

    Returns:
        A normalized copy of ``X`` (dense input stays dense; sparse input is
        returned in CSR form).
    """
    X = X.copy()
    counts = np.array(X.sum(axis=1))
    # Guard against empty cells: a zero total would divide by zero below.
    counts += counts == 0.
    # Scale every row so it sums to 10,000 counts.
    scaling_factor = 10000. / counts

    if issparse(X):
        # BUG FIX: the previous code did `from scipy.sparse import sparsefuncs`,
        # but `sparsefuncs` lives in sklearn.utils — any sparse input raised
        # ImportError. Scale the CSR data array row-by-row instead (scipy only):
        # each stored value in row i is multiplied by scaling_factor[i].
        X = X.tocsr()
        row_lengths = np.diff(X.indptr)
        X.data *= np.repeat(scaling_factor.ravel(), row_lengths)
    else:
        np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)

    return X
64
+
65
@numba.jit(nopython=True, nogil=True)
def _sub_tokenize_data(x: np.ndarray, max_seq_len: int = -1, aux_tokens: int = 30) -> np.ndarray:
    """Rank-tokenize each cell: genes sorted by descending expression.

    Args:
        x: Dense expression matrix (cells x genes).
        max_seq_len: Output sequence length; non-positive means "full gene
            dimension" (no truncation).
        aux_tokens: ID offset reserved for special/metadata tokens, so gene
            index g becomes token ``g + aux_tokens``.

    Returns:
        int32 matrix (cells x seq_len); positions past a cell's nonzero genes
        are 0 (the pad token).
    """
    # BUG FIX: the previous version treated the default -1 as truthy, so
    # `np.zeros(max_seq_len)` was called with -1 (error) and `[:max_seq_len]`
    # silently dropped the last ranked gene. Both paths now guard on `> 0`.
    seq_len = max_seq_len if max_seq_len > 0 else x.shape[1]
    # Allocate int32 directly instead of a float np.empty holding int tokens.
    scores_final = np.zeros((x.shape[0], seq_len), dtype=np.int32)
    for i in range(x.shape[0]):
        cell = x[i]
        nonzero_mask = np.nonzero(cell)[0]
        # Highest-expressed genes first; truncate only when a limit is set.
        sorted_indices = nonzero_mask[np.argsort(-cell[nonzero_mask])]
        if max_seq_len > 0:
            sorted_indices = sorted_indices[:max_seq_len]
        # Shift past the reserved auxiliary-token ID range.
        sorted_indices = sorted_indices + aux_tokens
        scores_final[i, :len(sorted_indices)] = sorted_indices.astype(np.int32)
    return scores_final
80
 
81
  class NicheformerTokenizer(PreTrainedTokenizer):
82
+ """Tokenizer for Nicheformer that handles single-cell data."""
 
83
 
 
 
 
 
 
 
84
  model_input_names = ["input_ids", "attention_mask"]
85
+ vocab_files_names = {"vocab_file": "vocab.json"}
86
+
87
+ modality_dict = MODALITY_DICT
88
+ species_dict = SPECIES_DICT
89
+ technology_dict = TECHNOLOGY_DICT
90
 
91
    def __init__(
        self,
        vocab_file=None,
        max_length: int = 1500,
        aux_tokens: int = 30,
        median_counts_per_gene: Optional[np.ndarray] = None,
        gene_names: Optional[List[str]] = None,
        **kwargs
    ):
        """Build the tokenizer and its vocabulary.

        Args:
            vocab_file: Optional path to a JSON token->ID mapping; its entries
                are merged on top of the built-in [PAD]/[MASK]/[CLS] tokens.
                When absent, metadata tokens are synthesized from the
                module-level dictionaries instead.
            max_length: Maximum token-sequence length produced per cell.
            aux_tokens: Number of low IDs reserved for special/metadata
                tokens; gene token IDs start at this offset.
            median_counts_per_gene: Optional per-gene medians used to rescale
                expression before tokenization.
            gene_names: Optional gene list; gene i is assigned token ID
                ``i + aux_tokens``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # Initialize base vocabulary
        self._vocabulary = {
            "[PAD]": PAD_TOKEN,
            "[MASK]": MASK_TOKEN,
            "[CLS]": CLS_TOKEN,
        }

        if vocab_file is not None:
            with open(vocab_file, 'r') as f:
                self._vocabulary.update(json.load(f))
        else:
            # Add modality tokens
            for name, idx in self.modality_dict.items():
                self._vocabulary[f"[MODALITY_{name}]"] = idx
            # Add species tokens
            for name, idx in self.species_dict.items():
                if name in ["Homo sapiens", "Mus musculus"]:
                    continue  # Skip redundant names
                self._vocabulary[f"[SPECIES_{name}]"] = idx
            # Add technology tokens
            for name, idx in self.technology_dict.items():
                if name in ["MERFISH", "10x transcription profiling"]:
                    continue  # Skip redundant names
                # Normalize to a bracketed identifier, e.g. "[TECH_10x_5__v2]".
                clean_name = name.lower().replace(" ", "_").replace("'", "_")
                self._vocabulary[f"[TECH_{clean_name}]"] = idx

        # Add gene tokens if provided
        if gene_names is not None:
            for i, gene in enumerate(gene_names):
                self._vocabulary[gene] = i + aux_tokens
            # Save vocabulary
            # NOTE(review): writing to the hard-coded relative path
            # 'to_hf/vocab.json' is a constructor side effect — confirm this
            # is intended; save_vocabulary() exists for explicit persistence.
            os.makedirs('to_hf', exist_ok=True)
            with open('to_hf/vocab.json', 'w') as f:
                json.dump(self._vocabulary, f, indent=4)

        super().__init__(**kwargs)

        self.max_length = max_length
        self.aux_tokens = aux_tokens
        self.median_counts_per_gene = median_counts_per_gene
        self.gene_names = gene_names

        # Set up special token mappings
        # NOTE(review): these assign private attributes directly instead of
        # passing pad_token/mask_token/cls_token to super().__init__ — verify
        # the base PreTrainedTokenizer machinery still registers them.
        self._pad_token = "[PAD]"
        self._mask_token = "[MASK]"
        self._cls_token = "[CLS]"
146
+
147
+ def get_vocab(self) -> Dict[str, int]:
148
+ """Returns the vocabulary mapping."""
149
+ return self._vocabulary.copy()
150
 
151
+ def _tokenize(self, text: str) -> List[str]:
152
+ """Tokenize text input."""
153
+ # This tokenizer doesn't handle text input directly
154
+ raise NotImplementedError("This tokenizer only works with gene expression data")
155
 
156
+ def _convert_token_to_id(self, token: str) -> int:
157
+ """Convert token to ID."""
158
+ # First check special token mappings
159
+ if token in self.modality_dict:
160
+ return self.modality_dict[token]
161
+ if token in self.species_dict:
162
+ return self.species_dict[token]
163
+ if token in self.technology_dict:
164
+ return self.technology_dict[token]
165
+ # Then check vocabulary
166
+ return self._vocabulary.get(token, self._vocabulary["[PAD]"])
167
 
168
+ def _convert_id_to_token(self, index: int) -> str:
169
+ """Convert ID to token."""
170
+ # First check special token mappings
171
+ for token, idx in self.modality_dict.items():
172
+ if idx == index:
173
+ return token
174
+ for token, idx in self.species_dict.items():
175
+ if idx == index:
176
+ return token
177
+ for token, idx in self.technology_dict.items():
178
+ if idx == index:
179
+ return token
180
+ # Then check vocabulary
181
+ for token, idx in self._vocabulary.items():
182
+ if idx == index:
183
+ return token
184
+ return "[PAD]"
185
 
186
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  """Save the vocabulary to a file."""
188
  vocab_file = os.path.join(
189
  save_directory,
190
+ (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
191
  )
192
 
193
  with open(vocab_file, "w", encoding="utf-8") as f:
194
+ json.dump(self._vocabulary, f, ensure_ascii=False)
195
 
196
  return (vocab_file,)
197
+
198
+ def _tokenize_gene_expression(self, x: np.ndarray) -> np.ndarray:
199
+ """Tokenize gene expression matrix.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  Args:
202
+ x: Gene expression matrix (cells x genes)
 
 
203
 
204
  Returns:
205
+ Tokenized matrix
206
  """
207
+ # Handle sparse input
208
+ if issparse(x):
209
+ x = x.toarray()
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ # Normalize and scale
212
+ x = np.nan_to_num(x)
213
+ x = sf_normalize(x)
214
+ if self.median_counts_per_gene is not None:
215
+ median_counts = self.median_counts_per_gene.copy()
216
+ median_counts += median_counts == 0
217
+ x = x / median_counts.reshape((1, -1))
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
+ # Convert to tokens
220
+ tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
+ return tokens.astype(np.int32)
223
 
224
    def __call__(
        self,
        adata: Optional[ad.AnnData] = None,
        gene_expression: Optional[Union[np.ndarray, List[float]]] = None,
        modality: Optional[str] = None,
        species: Optional[str] = None,
        technology: Optional[str] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Convert inputs to model inputs.

        Exactly one of *adata* or *gene_expression* must be given. Metadata
        tokens (modality, species, technology) are prepended to each cell's
        token sequence when available; the combined sequence is capped at
        ``self.max_length``.

        Args:
            adata: AnnData object; expression read from ``adata.X`` and
                missing metadata filled from ``adata.obs`` columns
                'modality' / 'specie' / 'assay'.
            gene_expression: Gene expression matrix if not using AnnData;
                a 1-D vector is treated as a single cell.
            modality: Modality type (per-cell sequence, or a scalar for the
                gene_expression path).
            species: Species type (same convention as *modality*).
            technology: Technology/assay type (same convention as *modality*).

        Returns:
            Dictionary with "input_ids" (long tensor) and "attention_mask"
            (bool tensor, False at [PAD] positions).

        Raises:
            ValueError: If neither *adata* nor *gene_expression* is provided.
        """
        if adata is not None:
            # Get expression matrix
            if issparse(adata.X):
                x = adata.X.toarray()
            else:
                x = adata.X

            # Get metadata for each cell if not provided
            # NOTE(review): on this path an explicitly passed *scalar* string
            # (e.g. modality="spatial") is iterated character-by-character in
            # the lookups below — callers should pass per-cell sequences here;
            # confirm intended contract.
            if modality is None and 'modality' in adata.obs:
                modality = adata.obs['modality'].values
            if species is None and 'specie' in adata.obs:
                species = adata.obs['specie'].values
            if technology is None and 'assay' in adata.obs:
                technology = adata.obs['assay'].values

        elif gene_expression is not None:
            x = np.array(gene_expression)
            if len(x.shape) == 1:
                x = x.reshape(1, -1)
            # For single gene expression input, convert scalar metadata to arrays
            if modality is not None:
                modality = np.array([modality])
            if species is not None:
                species = np.array([species])
            if technology is not None:
                technology = np.array([technology])
        else:
            raise ValueError("Either adata or gene_expression must be provided")

        # Tokenize gene expression
        token_ids = self._tokenize_gene_expression(x)
        n_cells = token_ids.shape[0]

        # Add special tokens for each cell
        special_tokens = np.zeros((n_cells, 3), dtype=np.int32)  # 3 for modality, species, technology
        special_token_mask = np.zeros((n_cells, 3), dtype=bool)  # Track which tokens are actually present

        # Unknown metadata strings fall back to the [PAD] ID (0).
        if modality is not None:
            special_tokens[:, 0] = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality]
            special_token_mask[:, 0] = True

        if species is not None:
            special_tokens[:, 1] = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species]
            special_token_mask[:, 1] = True

        if technology is not None:
            special_tokens[:, 2] = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology]
            special_token_mask[:, 2] = True

        # Only keep the special tokens that are present (have True in mask)
        # (mask columns are uniform across cells, so row 0 represents all rows)
        special_tokens = special_tokens[:, special_token_mask[0]]

        if special_tokens.size > 0:
            # Prepend metadata tokens; truncate gene tokens so the total
            # sequence length never exceeds max_length.
            token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)

        # Create attention mask
        attention_mask = (token_ids != self._vocabulary["[PAD]"])

        return {
            "input_ids": torch.tensor(token_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask)
        }
307
 
308
+ def get_vocab_size(self) -> int:
309
+ """Get vocabulary size."""
310
+ if self.gene_names is not None:
311
+ return len(self.gene_names) + self.aux_tokens
312
+ return max(
313
+ max(self.modality_dict.values()),
314
+ max(self.species_dict.values()),
315
+ max(self.technology_dict.values())
316
+ ) + 1
317
+
318
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
319
+ """Convert a sequence of tokens to a string. Not used for gene expression."""
320
+ raise NotImplementedError("This tokenizer only works with gene expression data")
321
+
322
+ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
323
+ """Build model inputs from a sequence by adding special tokens."""
324
+ # For gene expression data, special tokens are handled in __call__
325
+ return token_ids_0
326
+
327
+ def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
328
+ """Get list where entries are [1] if a token is [special] else [0]."""
329
+ # Consider tokens < aux_tokens as special
330
+ return [1 if token_id < self.aux_tokens else 0 for token_id in token_ids_0]