"""
Tokenization for VCF data with support for hierarchical structures.

``HierarchicalVCFTokenizer`` keeps one ``token -> id`` vocabulary per mutation
field and encodes nested ``{pathway: {chromosome: {gene: [mutations]}}}``
samples into the same nested shape with integer ids.
``HierarchicalDataCollator`` pads/masks encoded samples and gathers per-sample
size metadata.
"""
import json
import pickle
import logging
from pathlib import Path
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, Optional, Union, Any

import numpy as np
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import AddedToken

from config import DataConfig, ConfigManager
# NOTE(review): on Python <= 3.9 the stdlib also ships a module named
# ``parser``; confirm the project-local module is the one being imported.
from parser import MutationRecord

# NOTE(review): pickle, defaultdict, np and AddedToken are not referenced in
# this module's code.

# Configure logging (this configures the root logger at import time).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class HierarchicalVCFTokenizer(PreTrainedTokenizer):
    """Tokenizer with an independent vocabulary for every mutation field.

    ``field_vocabs[field]`` maps a raw value (e.g. ``"HIGH"``, ``"A"``,
    ``"1"``) to an integer id. Every field vocabulary starts with the four
    special tokens at ids 0-3 (pad, unk, sep, cls), then a few built-in
    genomic tokens, then whatever ``build_vocabulary`` or a loaded
    vocabulary file adds.
    """

    vocab_files_names = {
        "vocab_file": "vocab.json",
        "mutation_vocab_file": "mutation_vocab.json"
    }

    # Per-mutation attributes produced by ``_encode_mutations``, in output order.
    _MUTATION_VALUE_FIELDS = ('impact', 'ref', 'alt')

    def __init__(self,
                 vocab_file: Optional[str] = None,
                 mutation_vocab_file: Optional[str] = None,
                 config: Optional[DataConfig] = None,
                 **kwargs):
        """
        Args:
            vocab_file: Optional JSON vocabulary to load if the file exists.
            mutation_vocab_file: Optional JSON mutation vocabulary to load if
                the file exists.
            config: Data configuration whose ``special_tokens`` mapping gives
                the default special-token strings; ``DataConfig()`` if omitted.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Explicit
                ``pad_token``/``unk_token``/``sep_token``/``cls_token`` values
                here override the config.
        """
        self.config = config or DataConfig()
        special_tokens = self.config.special_tokens

        def _special(name: str, default: str):
            # Pop so the same keyword is not passed twice to the base
            # constructor (from_pretrained supplies these via kwargs).
            explicit = kwargs.pop(name, None)
            return explicit if explicit is not None else special_tokens.get(name, default)

        pad_token = _special("pad_token", "[PAD]")
        unk_token = _special("unk_token", "[UNK]")
        sep_token = _special("sep_token", "[SEP]")
        cls_token = _special("cls_token", "[CLS]")

        # The vocabularies must exist before PreTrainedTokenizer.__init__ runs:
        # newer transformers releases add special tokens from the base
        # constructor via _add_tokens(), which calls get_vocab(), which reads
        # self.field_vocabs. self.pad_token etc. are not available yet, so the
        # token strings are passed explicitly.
        self.mutation_fields = ['impact', 'ref', 'alt', 'chromosome', 'pathway', 'gene']
        self.field_vocabs: Dict[str, Dict[str, int]] = {}
        self._initialize_vocabularies(
            special_tokens=[str(pad_token), str(unk_token), str(sep_token), str(cls_token)]
        )

        # Load existing vocabularies if provided
        if vocab_file and Path(vocab_file).exists():
            self.load_vocabulary(vocab_file)
        if mutation_vocab_file and Path(mutation_vocab_file).exists():
            self.load_mutation_vocabulary(mutation_vocab_file)

        # Running counters updated by build_vocabulary / encode_batch.
        self.tokenization_stats = {
            'total_samples': 0,
            'total_mutations': 0,
            'vocab_sizes': {}
        }

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            **kwargs
        )

    def _initialize_vocabularies(self, special_tokens: Optional[List[str]] = None) -> None:
        """Reset every field vocabulary to the special tokens plus common genomic tokens.

        Args:
            special_tokens: Special-token strings assigned ids 0, 1, 2, ... in
                order. Defaults to the tokenizer's pad/unk/sep/cls tokens,
                which only exist after the base constructor has run.
        """
        if special_tokens is None:
            special_tokens = [self.pad_token, self.unk_token, self.sep_token, self.cls_token]
        for field in self.mutation_fields:
            self.field_vocabs[field] = {token: idx for idx, token in enumerate(special_tokens)}

        # Add common genomic tokens
        self._add_common_genomic_tokens()

    def _add_common_genomic_tokens(self) -> None:
        """Seed vocabularies with common impact, nucleotide and chromosome values.

        TODO: make this scalable and dynamic (the value lists are hard-coded).
        """
        self._extend_field_vocab('impact', ["HIGH", "MODERATE", "LOW", "MODIFIER"])

        nucleotides = ["A", "T", "G", "C", "N", "-"]
        self._extend_field_vocab('ref', nucleotides)
        self._extend_field_vocab('alt', nucleotides)

        chromosomes = [str(i) for i in range(1, 23)] + ["X", "Y", "MT"]
        self._extend_field_vocab('chromosome', chromosomes)

    def _extend_field_vocab(self, field: str, tokens) -> None:
        """Give each truthy, not-yet-present token the next id (current size) in ``field``'s vocabulary."""
        vocab = self.field_vocabs[field]
        for token in tokens:
            if token and token not in vocab:
                vocab[token] = len(vocab)

    def _mutation_values(self, mutation: Any) -> Optional[Tuple[Any, Any, Any]]:
        """Return ``(impact, reference, alternate)`` for a mutation, or None if unsupported.

        ``MutationRecord`` attributes are read directly; dicts fall back to the
        UNK token for missing keys. Any other type yields None.
        """
        if isinstance(mutation, MutationRecord):
            return mutation.impact, mutation.reference, mutation.alternate
        if isinstance(mutation, dict):
            return (
                mutation.get('impact', self.unk_token),
                mutation.get('reference', self.unk_token),
                mutation.get('alternate', self.unk_token),
            )
        return None

    def _field_token_id(self, field: str, value: Any) -> int:
        """Return the id of *value* in ``field``'s vocabulary, or that field's UNK id."""
        vocab = self.field_vocabs[field]
        return vocab.get(value, vocab[self.unk_token])

    def build_vocabulary(self, hierarchical_data: Dict[str, Any]) -> None:
        """
        Extend the field vocabularies with every value seen in *hierarchical_data*.

        New values are appended after existing tokens in descending frequency
        order (ties keep first-seen order, per ``Counter.most_common``).
        Falsy values (empty string, None) are skipped.

        Args:
            hierarchical_data: Parsed VCF data structure of the form
                ``{sample_id: {pathway_id: {chrom_id: {gene_id: [mutation, ...]}}}}``;
                each mutation is a ``MutationRecord`` or a dict with
                ``impact``/``reference``/``alternate`` keys.
        """
        logger.info("Building vocabularies from hierarchical data...")

        vocab_counters = {field: Counter() for field in self.mutation_fields}

        for pathways in hierarchical_data.values():
            for pathway_id, chromosomes in pathways.items():
                vocab_counters['pathway'][pathway_id] += 1
                for chrom_id, genes in chromosomes.items():
                    vocab_counters['chromosome'][chrom_id] += 1
                    for gene_id, mutations in genes.items():
                        vocab_counters['gene'][gene_id] += 1
                        for mutation in mutations:
                            values = self._mutation_values(mutation)
                            if values is None:
                                continue  # unsupported type; also skipped when encoding
                            for field, value in zip(self._MUTATION_VALUE_FIELDS, values):
                                vocab_counters[field][value] += 1

        # Build vocabularies from counters
        for field, counter in vocab_counters.items():
            self._extend_field_vocab(field, (token for token, _ in counter.most_common()))

        # Update statistics
        self.tokenization_stats['vocab_sizes'] = self.get_all_vocab_sizes()
        logger.info(f"Vocabulary sizes: {self.tokenization_stats['vocab_sizes']}")

    def encode_hierarchical_sample(self, sample_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Encode a single hierarchical sample into tokenized format.

        Pathway, chromosome and gene keys are replaced by their ids, and each
        gene's mutation list becomes ``{'impact': [...], 'ref': [...], 'alt': [...]}``.
        Distinct unknown keys all map to the same UNK id; their contents are
        merged under that id rather than overwriting one another.

        Args:
            sample_data: Single sample from hierarchical data
                (``{pathway_id: {chrom_id: {gene_id: [mutation, ...]}}}``)

        Returns:
            Encoded sample with tokenized values
        """
        encoded_sample: Dict[int, Dict[int, Dict[int, Dict[str, List[int]]]]] = {}

        for pathway_id, chromosomes in sample_data.items():
            pathway_token = self._field_token_id('pathway', pathway_id)
            encoded_chromosomes = encoded_sample.setdefault(pathway_token, {})

            for chrom_id, genes in chromosomes.items():
                chrom_token = self._field_token_id('chromosome', chrom_id)
                encoded_genes = encoded_chromosomes.setdefault(chrom_token, {})

                for gene_id, mutations in genes.items():
                    gene_token = self._field_token_id('gene', gene_id)
                    encoded_mutations = self._encode_mutations(mutations)

                    existing = encoded_genes.get(gene_token)
                    if existing is None:
                        encoded_genes[gene_token] = encoded_mutations
                    else:
                        # Same id seen again (e.g. two unknown genes): keep both mutation lists.
                        for field, token_ids in encoded_mutations.items():
                            existing[field].extend(token_ids)

        return encoded_sample

    def _encode_mutations(self, mutations: List[Union[MutationRecord, Dict]]) -> Dict[str, List[int]]:
        """Encode a gene's mutation list into parallel id lists.

        Returns:
            ``{'impact': [...], 'ref': [...], 'alt': [...]}`` with one id per
            supported mutation; items that are neither ``MutationRecord`` nor
            ``dict`` are skipped. Unknown values map to the field's UNK id.
        """
        encoded_mutations = {field: [] for field in self._MUTATION_VALUE_FIELDS}

        for mutation in mutations:
            values = self._mutation_values(mutation)
            if values is None:
                continue
            for field, value in zip(self._MUTATION_VALUE_FIELDS, values):
                encoded_mutations[field].append(self._field_token_id(field, value))

        return encoded_mutations

    def encode_batch(self, batch_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Encode a batch of hierarchical samples.

        Also increments ``tokenization_stats['total_samples']`` and
        ``tokenization_stats['total_mutations']``.

        Args:
            batch_data: List of sample dictionaries

        Returns:
            List of encoded samples
        """
        encoded_batch = [self.encode_hierarchical_sample(sample) for sample in batch_data]

        self.tokenization_stats['total_samples'] += len(encoded_batch)
        self.tokenization_stats['total_mutations'] += sum(
            len(mutations['impact'])
            for sample in encoded_batch
            for chromosomes in sample.values()
            for genes in chromosomes.values()
            for mutations in genes.values()
        )
        return encoded_batch

    def decode_tokens(self, field: str, token_ids: List[int]) -> List[str]:
        """
        Decode token IDs back to original values.

        Args:
            field: Field name ('impact', 'ref', 'alt', etc.)
            token_ids: List of token IDs

        Returns:
            List of decoded tokens; ids not in the vocabulary decode to the UNK token.

        Raises:
            ValueError: If *field* has no vocabulary.
        """
        if field not in self.field_vocabs:
            raise ValueError(f"Unknown field: {field}")

        id_to_token = {v: k for k, v in self.field_vocabs[field].items()}
        return [id_to_token.get(token_id, self.unk_token) for token_id in token_ids]

    def get_vocab_size(self, field: str) -> int:
        """Get vocabulary size for a specific field (raises ValueError for unknown fields)."""
        if field not in self.field_vocabs:
            raise ValueError(f"Unknown field: {field}")
        return len(self.field_vocabs[field])

    def get_all_vocab_sizes(self) -> Dict[str, int]:
        """Get vocabulary sizes for all fields."""
        return {field: len(vocab) for field, vocab in self.field_vocabs.items()}

    def save_vocabulary(self, save_directory: Union[str, Path], filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """
        Write the field vocabularies and a tokenizer summary as JSON.

        Writes ``{prefix}mutation_vocab.json`` (all field vocabularies) and
        ``{prefix}tokenizer_config.json`` (class name, special tokens, vocab
        sizes, field list). No ``vocab.json`` is written even though
        ``vocab_files_names`` lists one.

        NOTE(review): when invoked via ``save_pretrained`` with no prefix, the
        ``tokenizer_config.json`` written here may replace the one written by
        transformers — confirm this is intended.

        Args:
            save_directory: Directory to save vocabularies (created if missing)
            filename_prefix: Optional prefix for filenames

        Returns:
            Tuple of saved file paths
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        prefix = f"{filename_prefix}_" if filename_prefix else ""

        # Save mutation vocabularies
        mutation_vocab_file = save_directory / f"{prefix}mutation_vocab.json"
        with open(mutation_vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self.field_vocabs, f, indent=2)

        # Save tokenizer configuration
        config_file = save_directory / f"{prefix}tokenizer_config.json"
        config_data = {
            'tokenizer_class': self.__class__.__name__,
            'special_tokens': {
                'pad_token': self.pad_token,
                'unk_token': self.unk_token,
                'sep_token': self.sep_token,
                'cls_token': self.cls_token
            },
            'vocab_sizes': self.get_all_vocab_sizes(),
            'mutation_fields': self.mutation_fields
        }

        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config_data, f, indent=2)

        logger.info(f"Vocabularies saved to {save_directory}")

        return str(mutation_vocab_file), str(config_file)

    def load_vocabulary(self, vocab_file: Union[str, Path]) -> None:
        """Replace field vocabularies with those stored in a JSON file.

        Only top-level keys that are names in ``mutation_fields`` are used;
        other keys are ignored. Must not read special-token attributes, since
        ``__init__`` calls it before the base constructor runs.

        Raises:
            FileNotFoundError: If *vocab_file* does not exist.
            ValueError: If the file does not contain a JSON object.
        """
        vocab_file = Path(vocab_file)
        if not vocab_file.exists():
            raise FileNotFoundError(f"Vocabulary file not found: {vocab_file}")

        with open(vocab_file, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)

        if not isinstance(vocab_data, dict):
            raise ValueError(f"Vocabulary file must contain a JSON object: {vocab_file}")

        # Update vocabularies
        for field, vocab in vocab_data.items():
            if field in self.mutation_fields:
                self.field_vocabs[field] = vocab

        logger.info(f"Vocabularies loaded from {vocab_file}")

    def load_mutation_vocabulary(self, mutation_vocab_file: Union[str, Path]) -> None:
        """Load mutation-specific vocabularies from file (same format as ``load_vocabulary``)."""
        self.load_vocabulary(mutation_vocab_file)

    def create_padding_masks(self, encoded_sample: Dict[str, Any],
                             max_lengths: Dict[str, int]) -> Dict[str, Any]:
        """
        Create padding masks for hierarchical data.

        Only the per-gene mutation id lists are padded/truncated; the
        pathway/chromosome/gene levels are copied as-is.

        Args:
            encoded_sample: Encoded sample data
            max_lengths: Maximum lengths keyed ``'mutations_<field>'``
                (default 100 when a key is absent)

        Returns:
            Sample where each mutation field becomes ``{'tokens': [...], 'mask': [...]}``
            with mask 1 for real ids and 0 for padding.
        """
        masked_sample = {}

        for pathway_token, chromosomes in encoded_sample.items():
            masked_sample[pathway_token] = {}

            for chrom_token, genes in chromosomes.items():
                masked_sample[pathway_token][chrom_token] = {}

                for gene_token, mutations in genes.items():
                    masked_mutations = {}

                    for field, token_list in mutations.items():
                        max_len = max_lengths.get(f'mutations_{field}', 100)

                        # Truncate, then pad up to max_len with the field's PAD id.
                        kept = token_list[:max_len]
                        n_pad = max_len - len(kept)
                        padding = [self.field_vocabs[field][self.pad_token]] * n_pad if n_pad > 0 else []

                        masked_mutations[field] = {
                            'tokens': kept + padding,
                            'mask': [1] * len(kept) + [0] * len(padding)
                        }

                    masked_sample[pathway_token][chrom_token][gene_token] = masked_mutations

        return masked_sample

    def get_tokenization_statistics(self) -> Dict[str, Any]:
        """Return a shallow copy of the running statistics with current vocab sizes."""
        stats = self.tokenization_stats.copy()
        stats['vocab_sizes'] = self.get_all_vocab_sizes()
        return stats

    # Hugging Face compatibility methods
    @property
    def vocab_size(self) -> int:
        """Total number of entries across all field vocabularies."""
        return sum(len(vocab) for vocab in self.field_vocabs.values())

    def _field_offsets(self) -> Dict[str, int]:
        """Return each field's starting offset in the flattened ``get_vocab()`` id space.

        Fields are laid out back to back in ``field_vocabs`` order; this
        assumes each field's ids are contiguous from 0.
        """
        offsets = {}
        offset = 0
        for field, vocab in self.field_vocabs.items():
            offsets[field] = offset
            offset += len(vocab)
        return offsets

    def get_vocab(self) -> Dict[str, int]:
        """Flatten all field vocabularies into ``{"field:token": global_id}``."""
        combined_vocab = {}
        for field, offset in self._field_offsets().items():
            for token, idx in self.field_vocabs[field].items():
                combined_vocab[f"{field}:{token}"] = idx + offset
        return combined_vocab

    def _tokenize(self, text: str) -> List[str]:
        # This is a simplified implementation for compatibility
        # In practice, hierarchical data should be processed differently
        return text.split()

    def _convert_token_to_id(self, token: str) -> int:
        """Map a ``"field:token"`` string to its global id (consistent with ``get_vocab``).

        Unknown tokens within a known field map to that field's UNK id (plus
        offset); strings without a known field prefix map to the impact
        field's UNK id (1 if that is missing).
        """
        if ':' in token:
            field, actual_token = token.split(':', 1)
            if field in self.field_vocabs:
                return self._field_offsets()[field] + self._field_token_id(field, actual_token)
        return self.field_vocabs.get('impact', {}).get(self.unk_token, 1)

    def _convert_id_to_token(self, index: int) -> str:
        """Inverse of ``_convert_token_to_id``: global id -> ``"field:token"``, else the UNK token."""
        for field, offset in self._field_offsets().items():
            vocab = self.field_vocabs[field]
            local_index = index - offset
            if 0 <= local_index < len(vocab):
                for token, idx in vocab.items():
                    if idx == local_index:
                        return f"{field}:{token}"
                break
        return self.unk_token


class HierarchicalDataCollator:
    """Pad/mask a batch of encoded hierarchical samples and record their sizes."""

    def __init__(self, tokenizer: HierarchicalVCFTokenizer, max_lengths: Optional[Dict[str, int]] = None):
        """
        Args:
            tokenizer: Tokenizer whose ``create_padding_masks`` is applied per sample.
            max_lengths: Length limits; defaults are used when None or empty.
                Only the ``mutations_*`` keys are read by ``create_padding_masks``.
        """
        self.tokenizer = tokenizer
        self.max_lengths = max_lengths or {
            'mutations_impact': 50,
            'mutations_ref': 50,
            'mutations_alt': 50,
            'genes_per_chromosome': 100,
            'chromosomes_per_pathway': 25,
            'pathways_per_sample': 50
        }

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate batch of hierarchical samples.

        Args:
            batch: List of encoded hierarchical samples

        Returns:
            Collated batch ready for model input: ``samples`` (masked samples),
            ``batch_size``, and ``metadata`` with per-sample pathway,
            chromosome, gene and mutation counts (taken before truncation).
        """
        collated_batch = {
            'samples': [],
            'batch_size': len(batch),
            'metadata': {
                'num_pathways': [],
                'num_chromosomes': [],
                'num_genes': [],
                'num_mutations': []
            }
        }

        for sample in batch:
            # Create padding masks
            masked_sample = self.tokenizer.create_padding_masks(sample, self.max_lengths)
            collated_batch['samples'].append(masked_sample)

            # Collect metadata from the unpadded sample
            num_pathways = len(sample)
            num_chromosomes = sum(len(chroms) for chroms in sample.values())
            num_genes = sum(
                len(genes) for chroms in sample.values() for genes in chroms.values()
            )
            num_mutations = sum(
                len(mutations.get('impact', []))
                for chroms in sample.values()
                for genes in chroms.values()
                for mutations in genes.values()
            )

            collated_batch['metadata']['num_pathways'].append(num_pathways)
            collated_batch['metadata']['num_chromosomes'].append(num_chromosomes)
            collated_batch['metadata']['num_genes'].append(num_genes)
            collated_batch['metadata']['num_mutations'].append(num_mutations)

        return collated_batch


def create_tokenizer_from_config(config_manager: ConfigManager) -> HierarchicalVCFTokenizer:
    """Create tokenizer from configuration manager."""
    return HierarchicalVCFTokenizer(config=config_manager.data_config)


# Example usage and testing
if __name__ == "__main__":
    # Example usage
    config_manager = ConfigManager()
    tokenizer = create_tokenizer_from_config(config_manager)

    # Example hierarchical data structure
    example_data = {
        'sample1': {
            'pathway1': {
                'chr1': {
                    'gene1': [
                        {
                            'impact': 'HIGH',
                            'reference': 'A',
                            'alternate': 'T'
                        }
                    ]
                }
            }
        }
    }

    # Build vocabulary
    tokenizer.build_vocabulary({'sample1': example_data['sample1']})

    # Encode sample
    encoded = tokenizer.encode_hierarchical_sample(example_data['sample1'])
    print(f"Encoded sample: {encoded}")

    # Save vocabulary
    tokenizer.save_vocabulary("./tokenizer_files")

    print(f"Tokenization statistics: {tokenizer.get_tokenization_statistics()}")