"""PyTorch Dataset implementations for hierarchical VCF data.

The module offers:

* ``HierarchicalVCFDataset`` - a ``torch.utils.data.Dataset`` over samples
  encoded by a ``HierarchicalVCFTokenizer``.
* ``HierarchicalVCFDataModule`` - train/validation/test splitting plus
  ``DataLoader`` construction.
* ``HuggingFaceDatasetAdapter`` - conversion to a Hugging Face ``DatasetDict``.
* Factory helpers driven by a ``ConfigManager`` and a synthetic-label helper
  for testing.
"""

import json
import logging
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union, Any, Callable

import numpy as np
import pandas as pd
import torch
from datasets import Dataset as HFDataset, DatasetDict
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer

from config import DataConfig, ModelConfig, ConfigManager
from parser import VCFParser, MutationRecord
from tokenizer import HierarchicalVCFTokenizer, HierarchicalDataCollator

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class HierarchicalVCFDataset(Dataset):
    """Dataset of hierarchically organised, tokenizer-encoded VCF samples.

    Each processed sample is a dict with the keys ``'sample_id'``,
    ``'encoded_data'`` and ``'raw_data'``. The encoded data is iterated in
    this module as ``{pathway: {chromosome: {gene: {...}}}}``.
    """

    def __init__(self,
                 data_source: Union[str, Path, Dict, List],
                 tokenizer: HierarchicalVCFTokenizer,
                 config: Optional[DataConfig] = None,
                 labels: Optional[Union[List, np.ndarray]] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 cache_processed_data: bool = True):
        """Initialize the Hierarchical VCF Dataset.

        Args:
            data_source: Path to data file, or preprocessed data dict/list
            tokenizer: Tokenizer for encoding mutations
            config: Data configuration (a default ``DataConfig`` when omitted)
            labels: Optional labels for supervised learning
            transform: Optional transform to apply to samples
            target_transform: Optional transform to apply to labels
            cache_processed_data: Whether to cache processed data

        Raises:
            FileNotFoundError: If ``data_source`` is a path that does not exist.
            ValueError: If the file format is unsupported, no sample survives
                processing, or the number of labels does not match the number
                of processed samples.
        """
        self.config = config or DataConfig()
        self.tokenizer = tokenizer
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.cache_processed_data = cache_processed_data

        # Load and process data
        self.raw_data = self._load_data(data_source)
        self.processed_data = self._process_data()

        # Validate data consistency
        self._validate_data()

        # Dataset statistics
        self.stats = self._compute_statistics()

        logger.info(f"Dataset initialized with {len(self.processed_data)} samples")
        logger.info(f"Dataset statistics: {self.stats}")

    @staticmethod
    def _index_sample_list(samples: List[Any]) -> Dict[str, Any]:
        """Turn a list of samples into a ``{"sample_<i>": sample}`` dict."""
        return {f"sample_{i}": sample for i, sample in enumerate(samples)}

    def _load_data(self, data_source: Union[str, Path, Dict, List]) -> Dict[str, Any]:
        """Load raw sample data from memory or from a file.

        Args:
            data_source: An already-loaded dict or list, or a path to a
                ``.json``, ``.pkl`` or ``.vcf`` file.

        Returns:
            Mapping of sample id to raw sample data. Lists (in memory or read
            from a file) are converted to ``{"sample_<i>": sample}``.

        Raises:
            FileNotFoundError: If the path does not exist.
            ValueError: If the file suffix is not supported.
        """
        # Data already loaded
        if isinstance(data_source, list):
            return self._index_sample_list(data_source)
        if isinstance(data_source, dict):
            return data_source

        # Load from file
        data_path = Path(data_source)
        if not data_path.exists():
            raise FileNotFoundError(f"Data file not found: {data_path}")

        suffix = data_path.suffix.lower()
        try:
            if suffix == '.json':
                with open(data_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            elif suffix == '.pkl':
                # SECURITY NOTE: unpickling executes arbitrary code; only load
                # pickle files that come from a trusted source.
                with open(data_path, 'rb') as f:
                    loaded = pickle.load(f)
            elif suffix == '.vcf':
                # Parse VCF file directly
                parser = VCFParser(config=self.config)
                loaded = parser.parse_vcf_file(data_path)
            else:
                raise ValueError(f"Unsupported file format: {data_path.suffix}")
        except Exception as e:
            logger.error(f"Error loading data from {data_path}: {e}")
            raise

        # A file holding a list of samples must be keyed the same way as an
        # in-memory list, otherwise _process_data() fails on ``.items()``.
        if isinstance(loaded, list):
            return self._index_sample_list(loaded)
        return loaded

    def _process_data(self) -> List[Dict[str, Any]]:
        """Process raw hierarchical data into the dataset format.

        Each raw sample is standardized, filtered by mutation count and then
        encoded with the tokenizer. A sample that raises during any of these
        steps is logged and skipped rather than aborting the whole load.

        Returns:
            List of dicts with keys ``'sample_id'``, ``'encoded_data'`` and
            ``'raw_data'``.
        """
        processed_samples = []

        for sample_id, sample_data in self.raw_data.items():
            try:
                # Convert to standard format if needed
                standardized_sample = self._standardize_sample_format(sample_data)

                # Filter samples based on configuration
                if not self._should_include_sample(standardized_sample):
                    continue

                # Encode the sample
                encoded_sample = self.tokenizer.encode_hierarchical_sample(
                    standardized_sample
                )
                processed_samples.append({
                    'sample_id': sample_id,
                    'encoded_data': encoded_sample,
                    # The standardized sample is kept only when
                    # cache_processed_data is False.
                    'raw_data': (standardized_sample
                                 if not self.cache_processed_data else None),
                })
            except Exception as e:
                # Deliberate best-effort: one bad sample must not stop loading.
                logger.warning(f"Error processing sample {sample_id}: {e}")
                continue

        return processed_samples

    def _standardize_sample_format(self, sample_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``sample_data`` in pathway -> chromosome -> gene form.

        Accepted inputs are ``{'mutations': [...]}``, a dict whose values are
        all dicts (treated as already hierarchical), or anything else, which
        is treated as a flat list of mutation dicts.
        """
        if isinstance(sample_data, dict):
            if 'mutations' in sample_data:
                # Format: {'mutations': [...]}
                return self._convert_flat_to_hierarchical(sample_data['mutations'])
            if all(isinstance(v, dict) for v in sample_data.values()):
                # Already in hierarchical format
                return sample_data

        # Assume it's a list of mutations
        return self._convert_flat_to_hierarchical(sample_data)

    def _convert_flat_to_hierarchical(self, mutations: List[Dict]) -> Dict[str, Any]:
        """Convert flat mutation list to hierarchical format.

        Mutations are grouped by ``'pathway'``, then ``'chromosome'`` (falling
        back to ``'chrom'``), then ``'gene'`` (falling back to ``'gene_id'``).
        Missing keys use the ``'Unknown_Pathway'`` / ``'Unknown'`` /
        ``'Unknown_Gene'`` placeholders.
        """
        hierarchical: Dict[str, Any] = {}

        for mutation in mutations:
            # Extract hierarchical keys
            pathway = mutation.get('pathway', 'Unknown_Pathway')
            chromosome = mutation.get('chromosome', mutation.get('chrom', 'Unknown'))
            gene = mutation.get('gene', mutation.get('gene_id', 'Unknown_Gene'))

            # Create the nested levels on demand, then add the mutation
            genes = hierarchical.setdefault(pathway, {}).setdefault(chromosome, {})
            genes.setdefault(gene, []).append(mutation)

        return hierarchical

    def _should_include_sample(self, sample_data: Dict[str, Any]) -> bool:
        """Determine if sample should be included based on filtering criteria.

        The total mutation count must lie within
        ``config.min_mutations_per_sample`` and
        ``config.max_mutations_per_sample`` (both inclusive).
        """
        # Count total mutations
        total_mutations = sum(
            len(gene_mutations)
            for pathway_data in sample_data.values()
            for chrom_data in pathway_data.values()
            for gene_mutations in chrom_data.values()
        )

        # Apply filters
        if total_mutations < self.config.min_mutations_per_sample:
            return False
        if total_mutations > self.config.max_mutations_per_sample:
            return False
        return True

    def _validate_data(self) -> None:
        """Check that there is at least one sample and labels line up.

        Raises:
            ValueError: If no sample was processed, or the number of labels
                differs from the number of processed samples.
        """
        if len(self.processed_data) == 0:
            raise ValueError("No valid samples found in dataset")

        if self.labels is not None and len(self.labels) != len(self.processed_data):
            raise ValueError(
                f"Number of labels ({len(self.labels)}) doesn't match "
                f"number of samples ({len(self.processed_data)})"
            )

    def _compute_statistics(self) -> Dict[str, Any]:
        """Compute dataset statistics.

        Returns:
            Dict with ``'num_samples'``, the per-sample lists
            ``'mutations_per_sample'``, ``'genes_per_sample'`` and
            ``'pathways_per_sample'``, the counts ``'unique_pathways'``,
            ``'unique_chromosomes'`` and ``'unique_genes'``, and (when there
            is at least one sample) the ``avg_*`` / ``std_*`` summaries of
            mutations and genes per sample.
        """
        stats: Dict[str, Any] = {
            'num_samples': len(self.processed_data),
            'mutations_per_sample': [],
            'genes_per_sample': [],
            'pathways_per_sample': [],
        }
        seen_pathways = set()
        seen_chromosomes = set()
        seen_genes = set()

        for sample in self.processed_data:
            encoded_data = sample['encoded_data']
            sample_genes = 0
            sample_mutations = 0

            for pathway_token, chromosomes in encoded_data.items():
                seen_pathways.add(pathway_token)
                for chrom_token, genes in chromosomes.items():
                    seen_chromosomes.add(chrom_token)
                    for gene_token, mutations in genes.items():
                        seen_genes.add(gene_token)
                        sample_genes += 1
                        # Count mutations (assuming 'impact' field exists)
                        if 'impact' in mutations:
                            sample_mutations += len(mutations['impact'])

            stats['mutations_per_sample'].append(sample_mutations)
            stats['genes_per_sample'].append(sample_genes)
            stats['pathways_per_sample'].append(len(encoded_data))

        stats['unique_pathways'] = len(seen_pathways)
        stats['unique_chromosomes'] = len(seen_chromosomes)
        stats['unique_genes'] = len(seen_genes)

        # Compute summary statistics
        if stats['mutations_per_sample']:
            stats['avg_mutations_per_sample'] = np.mean(stats['mutations_per_sample'])
            stats['std_mutations_per_sample'] = np.std(stats['mutations_per_sample'])

        if stats['genes_per_sample']:
            stats['avg_genes_per_sample'] = np.mean(stats['genes_per_sample'])
            stats['std_genes_per_sample'] = np.std(stats['genes_per_sample'])

        return stats

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.processed_data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return a single sample from the dataset.

        The returned dict is a shallow copy of the stored sample, with
        ``transform`` applied to ``'encoded_data'`` and, when labels are
        present, a ``'label'`` entry (after ``target_transform``).

        Raises:
            IndexError: If ``idx`` is past the end of the dataset.
        """
        if idx >= len(self.processed_data):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")

        sample = self.processed_data[idx].copy()

        # Apply transforms
        if self.transform:
            sample['encoded_data'] = self.transform(sample['encoded_data'])

        # Add label if available
        if self.labels is not None:
            label = self.labels[idx]
            if self.target_transform:
                label = self.target_transform(label)
            sample['label'] = label

        return sample

    def get_sample_by_id(self, sample_id: str) -> Optional[Dict[str, Any]]:
        """Return the sample whose ``'sample_id'`` matches, or ``None``."""
        for i, sample in enumerate(self.processed_data):
            if sample['sample_id'] == sample_id:
                return self[i]
        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Return a shallow copy of the dataset statistics."""
        return self.stats.copy()

    def save_dataset(self, save_path: Union[str, Path], format: str = 'pickle') -> None:
        """Save the processed data, labels, statistics and config to disk.

        Args:
            save_path: Path to save the dataset (parent directories are created)
            format: Save format ('pickle', 'json')

        Raises:
            ValueError: If ``format`` is not supported.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        labels = self.labels.tolist() if isinstance(self.labels, np.ndarray) else self.labels
        dataset_info = {
            'processed_data': self.processed_data,
            'labels': labels,
            'stats': self.stats,
            'config': self.config.__dict__ if hasattr(self.config, '__dict__') else None,
        }

        save_format = format.lower()
        if save_format == 'pickle':
            with open(save_path, 'wb') as f:
                pickle.dump(dataset_info, f)
        elif save_format == 'json':
            with open(save_path, 'w', encoding='utf-8') as f:
                # default=str stringifies anything json cannot encode natively.
                json.dump(dataset_info, f, indent=2, default=str)
        else:
            raise ValueError(f"Unsupported save format: {format}")

        logger.info(f"Dataset saved to {save_path}")

    @staticmethod
    def _restore_config(saved_config: Any) -> Any:
        """Rebuild a config object from what ``save_dataset`` stored.

        ``save_dataset`` writes ``config.__dict__`` (or ``None``), so the
        saved fields are copied onto a fresh ``DataConfig``; fields that
        cannot be set are logged and skipped.
        """
        if saved_config is None:
            return DataConfig()
        if not isinstance(saved_config, dict):
            return saved_config

        restored = DataConfig()
        for key, value in saved_config.items():
            try:
                setattr(restored, key, value)
            except (AttributeError, TypeError, ValueError):
                logger.warning(f"Could not restore config field {key!r}")
        return restored

    @classmethod
    def load_dataset(cls,
                     load_path: Union[str, Path],
                     tokenizer: HierarchicalVCFTokenizer,
                     format: str = 'auto') -> 'HierarchicalVCFDataset':
        """Load a dataset previously written by ``save_dataset``.

        Args:
            load_path: Path to load the dataset from
            tokenizer: Tokenizer instance
            format: Load format ('pickle', 'json', 'auto'); 'auto' picks
                pickle for a ``.pkl`` suffix and JSON otherwise

        Returns:
            Loaded dataset instance

        Raises:
            FileNotFoundError: If ``load_path`` does not exist.
            ValueError: If ``format`` is not supported.
        """
        load_path = Path(load_path)

        if not load_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {load_path}")

        # Determine format (suffix compared case-insensitively, as in _load_data)
        if format == 'auto':
            format = 'pickle' if load_path.suffix.lower() == '.pkl' else 'json'

        # Load data
        load_format = format.lower()
        if load_format == 'pickle':
            # SECURITY NOTE: unpickling executes arbitrary code; only load
            # pickle files that come from a trusted source.
            with open(load_path, 'rb') as f:
                dataset_info = pickle.load(f)
        elif load_format == 'json':
            with open(load_path, 'r', encoding='utf-8') as f:
                dataset_info = json.load(f)
        else:
            raise ValueError(f"Unsupported load format: {format}")

        # Create dataset instance without re-running __init__ (no re-encoding)
        dataset = cls.__new__(cls)
        dataset.tokenizer = tokenizer
        dataset.raw_data = {}
        dataset.processed_data = dataset_info['processed_data']
        dataset.labels = dataset_info.get('labels')
        dataset.stats = dataset_info.get('stats', {})
        dataset.config = cls._restore_config(dataset_info.get('config'))
        dataset.transform = None
        dataset.target_transform = None
        dataset.cache_processed_data = True

        return dataset


class HierarchicalVCFDataModule:
    """Manage train/validation/test splits of hierarchical VCF data."""

    def __init__(self,
                 data_source: Union[str, Path, Dict],
                 tokenizer: HierarchicalVCFTokenizer,
                 config: Optional[DataConfig] = None,
                 labels: Optional[Union[List, np.ndarray]] = None,
                 train_split: float = 0.8,
                 val_split: float = 0.1,
                 test_split: float = 0.1,
                 stratify: bool = True,
                 random_seed: int = 42):
        """Load the full dataset and split it.

        Args:
            data_source: Source of the data
            tokenizer: Tokenizer for encoding
            config: Data configuration
            labels: Labels for supervised learning
            train_split: Proportion for training
            val_split: Proportion for validation
            test_split: Proportion for testing
            stratify: Whether to stratify splits by labels
            random_seed: Random seed for reproducibility

        Raises:
            ValueError: If the three proportions do not sum to 1.0.
        """
        self.config = config or DataConfig()
        self.tokenizer = tokenizer
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.stratify = stratify
        self.random_seed = random_seed

        # Validate splits
        if abs(train_split + val_split + test_split - 1.0) > 1e-6:
            raise ValueError("Train, validation, and test splits must sum to 1.0")

        # Load full dataset; share self.config so the module and the dataset
        # never end up with two different default DataConfig instances.
        self.full_dataset = HierarchicalVCFDataset(
            data_source=data_source,
            tokenizer=tokenizer,
            config=self.config,
            labels=labels
        )

        # Create splits
        self.train_dataset, self.val_dataset, self.test_dataset = self._create_splits()

        logger.info("Data module initialized:")
        logger.info(f"  Train: {len(self.train_dataset)} samples")
        logger.info(f"  Validation: {len(self.val_dataset)} samples")
        logger.info(f"  Test: {len(self.test_dataset)} samples")

    def _create_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Split the full dataset into train, validation and test subsets.

        With ``stratify`` enabled and labels available, scikit-learn's
        ``train_test_split`` is applied twice (train vs. hold-out, then
        validation vs. test). Otherwise the indices are shuffled with a
        seeded generator and cut at the configured proportions.

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset).
        """
        indices = np.arange(len(self.full_dataset))
        labels = self.full_dataset.labels
        # Integer dtype matters: a float empty array cannot index ndarray labels.
        no_indices = np.array([], dtype=int)

        if self.stratify and labels is not None:
            # Stratified split
            from sklearn.model_selection import train_test_split

            holdout = self.val_split + self.test_split
            if holdout <= 0:
                # Nothing is held out; train_test_split rejects test_size=0.
                train_idx, val_idx, test_idx = indices, no_indices, no_indices
            else:
                # First split: train vs (val + test)
                train_idx, temp_idx = train_test_split(
                    indices,
                    test_size=holdout,
                    stratify=[labels[i] for i in indices],
                    random_state=self.random_seed
                )

                if self.val_split > 0 and self.test_split > 0:
                    # Second split: val vs test
                    val_idx, test_idx = train_test_split(
                        temp_idx,
                        test_size=self.test_split / holdout,
                        stratify=[labels[i] for i in temp_idx],
                        random_state=self.random_seed
                    )
                elif self.test_split > 0:
                    # No validation share: the whole hold-out is the test set.
                    val_idx, test_idx = no_indices, temp_idx
                else:
                    val_idx, test_idx = temp_idx, no_indices
        else:
            # Random split. A local RandomState yields the same permutation as
            # seeding the global NumPy RNG, without clobbering global state.
            rng = np.random.RandomState(self.random_seed)
            rng.shuffle(indices)

            train_end = int(self.train_split * len(indices))
            val_end = int((self.train_split + self.val_split) * len(indices))

            train_idx = indices[:train_end]
            val_idx = indices[train_end:val_end]
            test_idx = indices[val_end:]

        # Create subset datasets
        return (
            self._create_subset(train_idx),
            self._create_subset(val_idx),
            self._create_subset(test_idx),
        )

    def _create_subset(self, indices: np.ndarray) -> Dataset:
        """Create a subset dataset from indices.

        The subset shares the processed sample dicts with the full dataset
        (they are not copied) and gets its own statistics.
        """
        indices = np.asarray(indices, dtype=int)
        full = self.full_dataset

        subset_data = [full.processed_data[i] for i in indices]
        subset_labels = None

        if full.labels is not None:
            if isinstance(full.labels, np.ndarray):
                subset_labels = full.labels[indices]
            else:
                subset_labels = [full.labels[i] for i in indices]

        # Create new dataset instance without re-running __init__
        dataset = HierarchicalVCFDataset.__new__(HierarchicalVCFDataset)
        dataset.tokenizer = self.tokenizer
        dataset.config = self.config
        dataset.raw_data = {}
        dataset.processed_data = subset_data
        dataset.labels = subset_labels
        dataset.transform = None
        dataset.target_transform = None
        dataset.cache_processed_data = True
        dataset.stats = dataset._compute_statistics()

        return dataset

    def get_dataloaders(self,
                        batch_size: int = 16,
                        num_workers: int = 0,
                        collate_fn: Optional[Callable] = None
                        ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Build data loaders for the three splits.

        Only the training loader shuffles.

        Args:
            batch_size: Batch size for data loading
            num_workers: Number of worker processes
            collate_fn: Custom collate function; defaults to a
                ``HierarchicalDataCollator`` built from the tokenizer

        Returns:
            Tuple of (train_loader, val_loader, test_loader)
        """
        if collate_fn is None:
            collate_fn = HierarchicalDataCollator(self.tokenizer)

        def build(dataset: Dataset, shuffle: bool) -> DataLoader:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                collate_fn=collate_fn
            )

        train_loader = build(self.train_dataset, True)
        val_loader = build(self.val_dataset, False)
        test_loader = build(self.test_dataset, False)

        return train_loader, val_loader, test_loader


class HuggingFaceDatasetAdapter:
    """Convert hierarchical VCF data to Hugging Face Dataset format."""

    def __init__(self, vcf_dataset: HierarchicalVCFDataset):
        """Wrap ``vcf_dataset`` for conversion."""
        self.vcf_dataset = vcf_dataset

    def to_huggingface_dataset(self) -> DatasetDict:
        """Flatten every sample and wrap the result in a ``DatasetDict``.

        Each record holds ``'sample_id'``, ``'pathways'``, ``'num_pathways'``
        and ``'encoded_mutations'``, plus ``'label'`` when labels exist.

        Returns:
            HuggingFace DatasetDict with a single ``'train'`` split
        """
        labels = self.vcf_dataset.labels

        # Flatten hierarchical data for HF compatibility
        flattened_data = []
        for i, sample in enumerate(self.vcf_dataset.processed_data):
            encoded_data = sample['encoded_data']
            record = {
                'sample_id': sample['sample_id'],
                'pathways': list(encoded_data.keys()),
                'num_pathways': len(encoded_data),
                'encoded_mutations': self._flatten_mutations(encoded_data),
            }
            # Add labels if available
            if labels is not None:
                record['label'] = labels[i]
            flattened_data.append(record)

        # Create HuggingFace dataset
        hf_dataset = HFDataset.from_list(flattened_data)
        return DatasetDict({'train': hf_dataset})

    def _flatten_mutations(self, encoded_data: Dict) -> Dict[str, List]:
        """Flatten hierarchical mutations for HF compatibility.

        Concatenates the ``'impact'``, ``'ref'`` and ``'alt'`` entries of
        every gene into three flat lists.
        """
        flat: Dict[str, List] = {'impacts': [], 'refs': [], 'alts': []}
        field_to_output = (('impact', 'impacts'), ('ref', 'refs'), ('alt', 'alts'))

        for chromosomes in encoded_data.values():
            for genes in chromosomes.values():
                for mutations in genes.values():
                    for field, output_key in field_to_output:
                        if field in mutations:
                            flat[output_key].extend(mutations[field])

        return flat


def _require_vcf_path(config_manager: ConfigManager) -> DataConfig:
    """Return the data config, insisting that a VCF file path is set."""
    data_config = config_manager.data_config
    if not data_config.vcf_file_path:
        raise ValueError("VCF file path not specified in configuration")
    return data_config


def create_dataset_from_config(config_manager: ConfigManager,
                               tokenizer: HierarchicalVCFTokenizer,
                               labels: Optional[List] = None) -> HierarchicalVCFDataset:
    """Build a ``HierarchicalVCFDataset`` from ``config_manager.data_config``.

    Raises:
        ValueError: If ``vcf_file_path`` is not set in the data config.
    """
    data_config = _require_vcf_path(config_manager)

    return HierarchicalVCFDataset(
        data_source=data_config.vcf_file_path,
        tokenizer=tokenizer,
        config=data_config,
        labels=labels
    )


def create_data_module_from_config(config_manager: ConfigManager,
                                   tokenizer: HierarchicalVCFTokenizer,
                                   labels: Optional[List] = None) -> HierarchicalVCFDataModule:
    """Build a ``HierarchicalVCFDataModule`` from ``config_manager.data_config``.

    Raises:
        ValueError: If ``vcf_file_path`` is not set in the data config.
    """
    data_config = _require_vcf_path(config_manager)

    return HierarchicalVCFDataModule(
        data_source=data_config.vcf_file_path,
        tokenizer=tokenizer,
        config=data_config,
        labels=labels
    )


# Utility functions for data preprocessing
def create_synthetic_labels(dataset: HierarchicalVCFDataset,
                            label_type: str = 'random',
                            num_classes: int = 2) -> np.ndarray:
    """Create synthetic labels for testing purposes.

    Args:
        dataset: VCF dataset
        label_type: Type of labels ('random', 'mutation_count_based')
        num_classes: Number of classes for classification

    Returns:
        Array of synthetic labels

    Raises:
        ValueError: If ``label_type`` is unknown.
    """
    num_samples = len(dataset)

    if label_type == 'random':
        return np.random.randint(0, num_classes, size=num_samples)

    if label_type == 'mutation_count_based':
        # Create labels based on mutation count thresholds
        mutation_counts = np.asarray(dataset.stats['mutations_per_sample'])
        if mutation_counts.size == 0:
            # Avoid a NaN median (and its warning) on an empty dataset.
            return np.array([], dtype=int)

        if num_classes == 2:
            threshold = np.median(mutation_counts)
            return (mutation_counts > threshold).astype(int)

        # Divide into quantiles. The thresholds depend only on the whole
        # distribution, so they are computed once rather than per sample.
        percentiles = np.linspace(0, 100, num_classes + 1)
        thresholds = np.percentile(mutation_counts, percentiles[1:-1])
        # Thresholds are ascending, so a sample's label is the number of
        # thresholds strictly below its count.
        return np.searchsorted(thresholds, mutation_counts, side='left')

    raise ValueError(f"Unknown label type: {label_type}")


# Example usage and testing
if __name__ == "__main__":
    from tokenizer import create_tokenizer_from_config

    # Example usage
    config_manager = ConfigManager()
    config_manager.data_config.vcf_file_path = "example_data.json"

    # Create tokenizer
    tokenizer = create_tokenizer_from_config(config_manager)

    # Example data
    example_data = {
        'sample1': {
            'pathway1': {
                'chr1': {
                    'gene1': [
                        {'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'},
                        {'impact': 'MODERATE', 'reference': 'G', 'alternate': 'C'}
                    ]
                }
            }
        },
        'sample2': {
            'pathway2': {
                'chr2': {
                    'gene2': [
                        {'impact': 'LOW', 'reference': 'T', 'alternate': 'A'}
                    ]
                }
            }
        }
    }

    # Build tokenizer vocabulary
    tokenizer.build_vocabulary(example_data)

    # Create dataset
    dataset = HierarchicalVCFDataset(
        data_source=example_data,
        tokenizer=tokenizer
    )

    # Create synthetic labels
    labels = create_synthetic_labels(dataset, label_type='random', num_classes=2)
    dataset.labels = labels

    # Create data module. Stratification is disabled because two samples are
    # too few for scikit-learn's stratified train_test_split.
    data_module = HierarchicalVCFDataModule(
        data_source=example_data,
        tokenizer=tokenizer,
        labels=labels,
        train_split=0.6,
        val_split=0.2,
        test_split=0.2,
        stratify=False
    )

    # Get data loaders
    train_loader, val_loader, test_loader = data_module.get_dataloaders(batch_size=2)

    # Test data loading
    for batch in train_loader:
        print(f"Batch size: {batch['batch_size']}")
        print(f"Sample IDs: {[s.get('sample_id', 'N/A') for s in batch['samples']]}")
        break

    print(f"Dataset statistics: {dataset.get_statistics()}")