aletlvl committed on
Commit
6ff8934
·
verified ·
1 Parent(s): 4d79ef6

Tokenization fixed

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +14 -12
tokenization_nicheformer.py CHANGED
@@ -257,27 +257,27 @@ class NicheformerTokenizer(PreTrainedTokenizer):
257
  # Avoid division by zero
258
  safe_mean = np.maximum(self.technology_mean, 1e-6)
259
  x = x / safe_mean
260
-
261
  # Apply log1p transformation
262
  x = np.log1p(x)
263
-
264
  # Convert to tokens
265
  tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
266
-
267
  return tokens.astype(np.int32)
268
-
269
  def __call__(self, data: Union[ad.AnnData, np.ndarray], **kwargs) -> Dict[str, torch.Tensor]:
270
  """Tokenize gene expression data.
271
-
272
  Args:
273
  data: AnnData object or numpy array of gene expression data
274
-
275
  Returns:
276
  Dictionary with input_ids and attention_mask tensors
277
  """
278
  if isinstance(data, ad.AnnData):
279
  adata = data.copy()
280
-
281
  # Align with reference model if available
282
  if hasattr(self, '_load_reference_model'):
283
  reference_model = self._load_reference_model()
@@ -285,15 +285,15 @@ class NicheformerTokenizer(PreTrainedTokenizer):
285
  # Concatenate and then remove the reference
286
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
287
  adata = adata[1:]
288
-
289
  # Get gene expression data
290
  X = adata.X
291
-
292
  # Get metadata for special tokens
293
  modality = adata.obs.get('modality', None)
294
  species = adata.obs.get('specie', None)
295
  technology = adata.obs.get('assay', None)
296
-
297
  print(f"Modality: {modality}")
298
  print(f"Species: {species}")
299
  print(f"Technology: {technology}")
@@ -302,16 +302,18 @@ class NicheformerTokenizer(PreTrainedTokenizer):
302
  modality_tokens = modality.astype(int).tolist()
303
  else:
304
  modality_tokens = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality] if modality is not None else None
305
-
306
  if species is not None and pd.api.types.is_numeric_dtype(species):
307
  species_tokens = species.astype(int).tolist()
308
  else:
309
  species_tokens = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species] if species is not None else None
310
-
311
  if technology is not None and pd.api.types.is_numeric_dtype(technology):
312
  technology_tokens = technology.astype(int).tolist()
 
313
  else:
314
  technology_tokens = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology] if technology is not None else None
 
315
  else:
316
  X = data
317
  modality_tokens = None
 
257
  # Avoid division by zero
258
  safe_mean = np.maximum(self.technology_mean, 1e-6)
259
  x = x / safe_mean
260
+
261
  # Apply log1p transformation
262
  x = np.log1p(x)
263
+
264
  # Convert to tokens
265
  tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
266
+
267
  return tokens.astype(np.int32)
268
+
269
  def __call__(self, data: Union[ad.AnnData, np.ndarray], **kwargs) -> Dict[str, torch.Tensor]:
270
  """Tokenize gene expression data.
271
+
272
  Args:
273
  data: AnnData object or numpy array of gene expression data
274
+
275
  Returns:
276
  Dictionary with input_ids and attention_mask tensors
277
  """
278
  if isinstance(data, ad.AnnData):
279
  adata = data.copy()
280
+
281
  # Align with reference model if available
282
  if hasattr(self, '_load_reference_model'):
283
  reference_model = self._load_reference_model()
 
285
  # Concatenate and then remove the reference
286
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
287
  adata = adata[1:]
288
+
289
  # Get gene expression data
290
  X = adata.X
291
+
292
  # Get metadata for special tokens
293
  modality = adata.obs.get('modality', None)
294
  species = adata.obs.get('specie', None)
295
  technology = adata.obs.get('assay', None)
296
+
297
  print(f"Modality: {modality}")
298
  print(f"Species: {species}")
299
  print(f"Technology: {technology}")
 
302
  modality_tokens = modality.astype(int).tolist()
303
  else:
304
  modality_tokens = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality] if modality is not None else None
305
+
306
  if species is not None and pd.api.types.is_numeric_dtype(species):
307
  species_tokens = species.astype(int).tolist()
308
  else:
309
  species_tokens = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species] if species is not None else None
310
+
311
  if technology is not None and pd.api.types.is_numeric_dtype(technology):
312
  technology_tokens = technology.astype(int).tolist()
313
+ print(f"Technology tokens: {technology_tokens}")
314
  else:
315
  technology_tokens = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology] if technology is not None else None
316
+ print(f"Technology tokens resort: {technology_tokens}")
317
  else:
318
  X = data
319
  modality_tokens = None