Tokenization fixed
Browse files
- tokenization_nicheformer.py +58 -66
tokenization_nicheformer.py
CHANGED
|
@@ -9,6 +9,7 @@ import numba
|
|
| 9 |
import os
|
| 10 |
import json
|
| 11 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 12 |
|
| 13 |
# Token IDs must match exactly with the original implementation
|
| 14 |
PAD_TOKEN = 0
|
|
@@ -236,89 +237,80 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 236 |
|
| 237 |
return tokens.astype(np.int32)
|
| 238 |
|
| 239 |
-
def __call__(
|
| 240 |
-
|
| 241 |
-
adata: Optional[ad.AnnData] = None,
|
| 242 |
-
gene_expression: Optional[Union[np.ndarray, List[float]]] = None,
|
| 243 |
-
modality: Optional[str] = None,
|
| 244 |
-
species: Optional[str] = None,
|
| 245 |
-
technology: Optional[str] = None,
|
| 246 |
-
**kwargs
|
| 247 |
-
) -> Dict[str, torch.Tensor]:
|
| 248 |
-
"""Convert inputs to model inputs.
|
| 249 |
|
| 250 |
Args:
|
| 251 |
-
|
| 252 |
-
gene_expression: Gene expression matrix if not using AnnData
|
| 253 |
-
modality: Modality type
|
| 254 |
-
species: Species type
|
| 255 |
-
technology: Technology/assay type
|
| 256 |
|
| 257 |
Returns:
|
| 258 |
-
Dictionary with
|
| 259 |
"""
|
| 260 |
-
if
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
else:
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
# Get metadata for each cell if not provided
|
| 275 |
-
if modality is None and 'modality' in adata.obs:
|
| 276 |
-
modality = adata.obs['modality'].values
|
| 277 |
-
if species is None and 'specie' in adata.obs:
|
| 278 |
-
species = adata.obs['specie'].values
|
| 279 |
-
if technology is None and 'assay' in adata.obs:
|
| 280 |
-
technology = adata.obs['assay'].values
|
| 281 |
-
|
| 282 |
-
elif gene_expression is not None:
|
| 283 |
-
x = np.array(gene_expression)
|
| 284 |
-
if len(x.shape) == 1:
|
| 285 |
-
x = x.reshape(1, -1)
|
| 286 |
-
# For single gene expression input, convert scalar metadata to arrays
|
| 287 |
-
if modality is not None:
|
| 288 |
-
modality = np.array([modality])
|
| 289 |
-
if species is not None:
|
| 290 |
-
species = np.array([species])
|
| 291 |
-
if technology is not None:
|
| 292 |
-
technology = np.array([technology])
|
| 293 |
-
else:
|
| 294 |
-
raise ValueError("Either adata or gene_expression must be provided")
|
| 295 |
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
-
# Add special tokens
|
| 301 |
-
special_tokens = np.zeros((
|
| 302 |
-
special_token_mask = np.zeros((
|
| 303 |
|
| 304 |
-
if
|
| 305 |
-
special_tokens[:, 0] =
|
| 306 |
special_token_mask[:, 0] = True
|
| 307 |
-
|
| 308 |
-
if
|
| 309 |
-
special_tokens[:, 1] =
|
| 310 |
special_token_mask[:, 1] = True
|
| 311 |
-
|
| 312 |
-
if
|
| 313 |
-
special_tokens[:, 2] =
|
| 314 |
special_token_mask[:, 2] = True
|
| 315 |
-
|
| 316 |
# Only keep the special tokens that are present (have True in mask)
|
| 317 |
special_tokens = special_tokens[:, special_token_mask[0]]
|
| 318 |
|
| 319 |
if special_tokens.size > 0:
|
| 320 |
token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)
|
| 321 |
-
|
| 322 |
# Create attention mask
|
| 323 |
attention_mask = (token_ids != self._vocabulary["[PAD]"])
|
| 324 |
|
|
|
|
| 9 |
import os
|
| 10 |
import json
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
+
import pandas as pd
|
| 13 |
|
| 14 |
# Token IDs must match exactly with the original implementation
|
| 15 |
PAD_TOKEN = 0
|
|
|
|
| 237 |
|
| 238 |
return tokens.astype(np.int32)
|
| 239 |
|
| 240 |
+
def __call__(self, data: Union[ad.AnnData, np.ndarray], **kwargs) -> Dict[str, torch.Tensor]:
|
| 241 |
+
"""Tokenize gene expression data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
Args:
|
| 244 |
+
data: AnnData object or numpy array of gene expression data
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
Returns:
|
| 247 |
+
Dictionary with input_ids and attention_mask tensors
|
| 248 |
"""
|
| 249 |
+
if isinstance(data, ad.AnnData):
|
| 250 |
+
adata = data.copy()
|
| 251 |
+
|
| 252 |
+
# Align with reference model if available
|
| 253 |
+
if hasattr(self, '_load_reference_model'):
|
| 254 |
+
reference_model = self._load_reference_model()
|
| 255 |
+
if reference_model is not None:
|
| 256 |
+
# Concatenate and then remove the reference
|
| 257 |
+
adata = ad.concat([reference_model, adata], join='outer', axis=0)
|
| 258 |
+
adata = adata[1:]
|
| 259 |
+
|
| 260 |
+
# Get gene expression data
|
| 261 |
+
X = adata.X
|
| 262 |
+
|
| 263 |
+
# Get metadata for special tokens
|
| 264 |
+
modality = adata.obs.get('modality', None)
|
| 265 |
+
species = adata.obs.get('specie', None) # Note: using 'specie' as in the notebook
|
| 266 |
+
technology = adata.obs.get('assay', None) # Note: using 'assay' as in the notebook
|
| 267 |
+
|
| 268 |
+
# Use integer values directly if available
|
| 269 |
+
if modality is not None and pd.api.types.is_numeric_dtype(modality):
|
| 270 |
+
modality_tokens = modality.astype(int).tolist()
|
| 271 |
else:
|
| 272 |
+
modality_tokens = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality] if modality is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
+
if species is not None and pd.api.types.is_numeric_dtype(species):
|
| 275 |
+
species_tokens = species.astype(int).tolist()
|
| 276 |
+
else:
|
| 277 |
+
species_tokens = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species] if species is not None else None
|
| 278 |
+
|
| 279 |
+
if technology is not None and pd.api.types.is_numeric_dtype(technology):
|
| 280 |
+
technology_tokens = technology.astype(int).tolist()
|
| 281 |
+
else:
|
| 282 |
+
technology_tokens = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology] if technology is not None else None
|
| 283 |
+
else:
|
| 284 |
+
X = data
|
| 285 |
+
modality_tokens = None
|
| 286 |
+
species_tokens = None
|
| 287 |
+
technology_tokens = None
|
| 288 |
+
|
| 289 |
+
# Tokenize gene expression data
|
| 290 |
+
token_ids = self._tokenize_gene_expression(X)
|
| 291 |
|
| 292 |
+
# Add special tokens if available
|
| 293 |
+
special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
|
| 294 |
+
special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
|
| 295 |
|
| 296 |
+
if modality_tokens is not None:
|
| 297 |
+
special_tokens[:, 0] = modality_tokens
|
| 298 |
special_token_mask[:, 0] = True
|
| 299 |
+
|
| 300 |
+
if species_tokens is not None:
|
| 301 |
+
special_tokens[:, 1] = species_tokens
|
| 302 |
special_token_mask[:, 1] = True
|
| 303 |
+
|
| 304 |
+
if technology_tokens is not None:
|
| 305 |
+
special_tokens[:, 2] = technology_tokens
|
| 306 |
special_token_mask[:, 2] = True
|
| 307 |
+
|
| 308 |
# Only keep the special tokens that are present (have True in mask)
|
| 309 |
special_tokens = special_tokens[:, special_token_mask[0]]
|
| 310 |
|
| 311 |
if special_tokens.size > 0:
|
| 312 |
token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)
|
| 313 |
+
|
| 314 |
# Create attention mask
|
| 315 |
attention_mask = (token_ids != self._vocabulary["[PAD]"])
|
| 316 |
|