aletlvl committed on
Commit
5a0555b
·
verified ·
1 Parent(s): f495fe3

Tokenization fixed

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +58 -66
tokenization_nicheformer.py CHANGED
@@ -9,6 +9,7 @@ import numba
9
  import os
10
  import json
11
  from huggingface_hub import hf_hub_download
 
12
 
13
  # Token IDs must match exactly with the original implementation
14
  PAD_TOKEN = 0
@@ -236,89 +237,80 @@ class NicheformerTokenizer(PreTrainedTokenizer):
236
 
237
  return tokens.astype(np.int32)
238
 
239
- def __call__(
240
- self,
241
- adata: Optional[ad.AnnData] = None,
242
- gene_expression: Optional[Union[np.ndarray, List[float]]] = None,
243
- modality: Optional[str] = None,
244
- species: Optional[str] = None,
245
- technology: Optional[str] = None,
246
- **kwargs
247
- ) -> Dict[str, torch.Tensor]:
248
- """Convert inputs to model inputs.
249
 
250
  Args:
251
- adata: AnnData object
252
- gene_expression: Gene expression matrix if not using AnnData
253
- modality: Modality type
254
- species: Species type
255
- technology: Technology/assay type
256
 
257
  Returns:
258
- Dictionary with model inputs
259
  """
260
- if adata is not None:
261
- # Align with reference model if needed
262
- reference_model = self._load_reference_model()
263
- if reference_model is not None:
264
- # Concatenate and then remove the reference
265
- adata = ad.concat([reference_model, adata], join='outer', axis=0)
266
- adata = adata[1:]
267
-
268
- # Get expression matrix
269
- if issparse(adata.X):
270
- x = adata.X.toarray()
 
 
 
 
 
 
 
 
 
 
 
271
  else:
272
- x = adata.X
273
-
274
- # Get metadata for each cell if not provided
275
- if modality is None and 'modality' in adata.obs:
276
- modality = adata.obs['modality'].values
277
- if species is None and 'specie' in adata.obs:
278
- species = adata.obs['specie'].values
279
- if technology is None and 'assay' in adata.obs:
280
- technology = adata.obs['assay'].values
281
-
282
- elif gene_expression is not None:
283
- x = np.array(gene_expression)
284
- if len(x.shape) == 1:
285
- x = x.reshape(1, -1)
286
- # For single gene expression input, convert scalar metadata to arrays
287
- if modality is not None:
288
- modality = np.array([modality])
289
- if species is not None:
290
- species = np.array([species])
291
- if technology is not None:
292
- technology = np.array([technology])
293
- else:
294
- raise ValueError("Either adata or gene_expression must be provided")
295
 
296
- # Tokenize gene expression
297
- token_ids = self._tokenize_gene_expression(x)
298
- n_cells = token_ids.shape[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
- # Add special tokens for each cell
301
- special_tokens = np.zeros((n_cells, 3), dtype=np.int32) # 3 for modality, species, technology
302
- special_token_mask = np.zeros((n_cells, 3), dtype=bool) # Track which tokens are actually present
303
 
304
- if modality is not None:
305
- special_tokens[:, 0] = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality]
306
  special_token_mask[:, 0] = True
307
-
308
- if species is not None:
309
- special_tokens[:, 1] = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species]
310
  special_token_mask[:, 1] = True
311
-
312
- if technology is not None:
313
- special_tokens[:, 2] = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology]
314
  special_token_mask[:, 2] = True
315
-
316
  # Only keep the special tokens that are present (have True in mask)
317
  special_tokens = special_tokens[:, special_token_mask[0]]
318
 
319
  if special_tokens.size > 0:
320
  token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)
321
-
322
  # Create attention mask
323
  attention_mask = (token_ids != self._vocabulary["[PAD]"])
324
 
 
9
  import os
10
  import json
11
  from huggingface_hub import hf_hub_download
12
+ import pandas as pd
13
 
14
  # Token IDs must match exactly with the original implementation
15
  PAD_TOKEN = 0
 
237
 
238
  return tokens.astype(np.int32)
239
 
240
+ def __call__(self, data: Union[ad.AnnData, np.ndarray], **kwargs) -> Dict[str, torch.Tensor]:
241
+ """Tokenize gene expression data.
 
 
 
 
 
 
 
 
242
 
243
  Args:
244
+ data: AnnData object or numpy array of gene expression data
 
 
 
 
245
 
246
  Returns:
247
+ Dictionary with input_ids and attention_mask tensors
248
  """
249
+ if isinstance(data, ad.AnnData):
250
+ adata = data.copy()
251
+
252
+ # Align with reference model if available
253
+ if hasattr(self, '_load_reference_model'):
254
+ reference_model = self._load_reference_model()
255
+ if reference_model is not None:
256
+ # Concatenate and then remove the reference
257
+ adata = ad.concat([reference_model, adata], join='outer', axis=0)
258
+ adata = adata[1:]
259
+
260
+ # Get gene expression data
261
+ X = adata.X
262
+
263
+ # Get metadata for special tokens
264
+ modality = adata.obs.get('modality', None)
265
+ species = adata.obs.get('specie', None) # Note: using 'specie' as in the notebook
266
+ technology = adata.obs.get('assay', None) # Note: using 'assay' as in the notebook
267
+
268
+ # Use integer values directly if available
269
+ if modality is not None and pd.api.types.is_numeric_dtype(modality):
270
+ modality_tokens = modality.astype(int).tolist()
271
  else:
272
+ modality_tokens = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality] if modality is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
+ if species is not None and pd.api.types.is_numeric_dtype(species):
275
+ species_tokens = species.astype(int).tolist()
276
+ else:
277
+ species_tokens = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species] if species is not None else None
278
+
279
+ if technology is not None and pd.api.types.is_numeric_dtype(technology):
280
+ technology_tokens = technology.astype(int).tolist()
281
+ else:
282
+ technology_tokens = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology] if technology is not None else None
283
+ else:
284
+ X = data
285
+ modality_tokens = None
286
+ species_tokens = None
287
+ technology_tokens = None
288
+
289
+ # Tokenize gene expression data
290
+ token_ids = self._tokenize_gene_expression(X)
291
 
292
+ # Add special tokens if available
293
+ special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
294
+ special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
295
 
296
+ if modality_tokens is not None:
297
+ special_tokens[:, 0] = modality_tokens
298
  special_token_mask[:, 0] = True
299
+
300
+ if species_tokens is not None:
301
+ special_tokens[:, 1] = species_tokens
302
  special_token_mask[:, 1] = True
303
+
304
+ if technology_tokens is not None:
305
+ special_tokens[:, 2] = technology_tokens
306
  special_token_mask[:, 2] = True
307
+
308
  # Only keep the special tokens that are present (have True in mask)
309
  special_tokens = special_tokens[:, special_token_mask[0]]
310
 
311
  if special_tokens.size > 0:
312
  token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)
313
+
314
  # Create attention mask
315
  attention_mask = (token_ids != self._vocabulary["[PAD]"])
316