from typing import List, Dict, Optional, Union, Tuple
import os
from transformers import PreTrainedTokenizer
from itertools import product
import json


class NucEL_Tokenizer(PreTrainedTokenizer):
    """
    K-mer tokenizer for DNA sequences, inheriting from Hugging Face's PreTrainedTokenizer.
    Handles k-mer tokenization with support for special tokens, padding, and truncation.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        k: int = 6,
        model_max_length: int = 2048,
        pad_token: str = "[PAD]",
        unk_token: str = "[UNK]",
        sep_token: str = "[SEP]",
        cls_token: str = "[CLS]",
        mask_token: str = "[MASK]",
        bos_token: str = "[BOS]",
        eos_token: str = "[EOS]",
        num_reserved_tokens: int = 16,
        **kwargs,
    ):
        """Initialize the k-mer tokenizer."""
        self.k = k
        self.nucleotides = ['A', 'C', 'G', 'T']
        self.num_reserved_tokens = num_reserved_tokens

        # Define special tokens.
        self.special_tokens = {
            "pad_token": pad_token,
            "unk_token": unk_token,
            "sep_token": sep_token,
            "cls_token": cls_token,
            "mask_token": mask_token,
            "bos_token": bos_token,
            "eos_token": eos_token,
        }

        # Build the vocabulary (special tokens, nucleotides, and k-mers) before
        # calling the parent constructor, which may query vocab_size/get_vocab().
        self._init_vocabulary()

        super().__init__(
            model_max_length=model_max_length,
            pad_token=pad_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            mask_token=mask_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs,
        )

    def _init_vocabulary(self):
        """Initialize the vocabulary with special tokens, nucleotides, and k-mers."""
        # Special tokens first, in a fixed order so their IDs stay stable.
        special_tokens = [
            self.special_tokens["pad_token"],
            self.special_tokens["unk_token"],
            self.special_tokens["cls_token"],
            self.special_tokens["sep_token"],
            self.special_tokens["mask_token"],
            self.special_tokens["bos_token"],
            self.special_tokens["eos_token"],
        ]

        # Individual nucleotides act as a fallback for sequence tails shorter than k.
        nucleotides = self.nucleotides

        # All 4^k possible k-mers over the DNA alphabet.
        kmers = [''.join(p) for p in product(self.nucleotides, repeat=self.k)]

        # Reserved tokens leave room for future additions without shifting existing IDs.
        reserved_tokens = [f"[RESERVED_{i}]" for i in range(self.num_reserved_tokens)]

        # Combine all tokens in a fixed order and map token -> index.
        all_tokens = special_tokens + nucleotides + kmers + reserved_tokens
        self.vocab = {token: idx for idx, token in enumerate(all_tokens)}

        # Reverse mapping: index -> token.
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary dictionary."""
        return self.vocab.copy()

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a DNA sequence into k-mers, falling back to individual nucleotides.

        Args:
            text: DNA sequence to tokenize.

        Returns:
            List of tokens, with the CLS token prepended.
""" text = text.upper().strip() tokens = [self.cls_token] i = 0 while i < len(text): # Try to get a k-mer if i <= len(text) - self.k: kmer = text[i:i+self.k] if kmer in self.vocab: tokens.append(kmer) i += self.k continue # Fallback: tokenize a single nucleotide if i < len(text): nucleotide = text[i] if nucleotide in self.nucleotides: tokens.append(nucleotide) else: tokens.append(self.unk_token) i += 1 return tokens def _convert_token_to_id(self, token: str) -> int: """Convert a token to its ID in the vocabulary.""" return self.vocab.get(token, self.vocab[self.unk_token]) def _convert_id_to_token(self, index: int) -> str: """Convert an ID to its token in the vocabulary.""" return self.ids_to_tokens.get(index, self.unk_token) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: """Save the tokenizer vocabulary to a directory.""" if not filename_prefix: filename_prefix = "vocab" vocab_file = os.path.join(save_directory, f"{filename_prefix}.json") with open(vocab_file, 'w', encoding='utf-8') as f: json.dump(self.vocab, f, ensure_ascii=False, indent=2) return (vocab_file,) def save_pretrained(self, save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None, **kwargs): """ Save the tokenizer configuration and vocabulary. """ # Save the vocabulary vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) # Save the config config = { 'k': self.k, 'model_max_length': self.model_max_length, 'padding_side': self.padding_side, 'truncation_side': self.truncation_side, 'special_tokens': { 'pad_token': self.pad_token, 'unk_token': self.unk_token, 'sep_token': self.sep_token, 'cls_token': self.cls_token, 'mask_token': self.mask_token, 'bos_token': self.bos_token, 'eos_token': self.eos_token, } } super().save_pretrained(save_directory, config=config, legacy_format=legacy_format, **kwargs) return vocab_files @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs): """ Load a tokenizer from a pretrained model. """ # Load the tokenizer configuration config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json") with open(config_file, 'r', encoding='utf-8') as f: config = json.load(f) # Load the vocabulary vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json") with open(vocab_file, 'r', encoding='utf-8') as f: vocab = json.load(f) # Extract k from config (add it to your tokenizer_config.json if not present) k = config.get('k', 6) # Create tokenizer instance - tokens are at top level in tokenizer_config.json tokenizer = cls( k=k, model_max_length=config.get('model_max_length', 2048), pad_token=config.get('pad_token', '[PAD]'), unk_token=config.get('unk_token', '[UNK]'), sep_token=config.get('sep_token', '[SEP]'), cls_token=config.get('cls_token', '[CLS]'), mask_token=config.get('mask_token', '[MASK]'), bos_token=config.get('bos_token', '[BOS]'), eos_token=config.get('eos_token', '[EOS]'), **kwargs ) # Override the vocabulary with the saved one tokenizer.vocab = vocab tokenizer.ids_to_tokens = {idx: token for token, idx in vocab.items()} return tokenizer