|
|
""" |
|
|
This module provides PyTorch Dataset implementations for hierarchical VCF data |
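
Example usage (a minimal sketch; assumes a HierarchicalVCFTokenizer whose
vocabulary has been built):

    dataset = HierarchicalVCFDataset("mutations.json", tokenizer=tokenizer)
    sample = dataset[0]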
|
|
""" |
|
|
|
|
|
import torch |
|
|
import json |
|
|
import pickle |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Optional, Union, Any, Callable |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
from datasets import Dataset as HFDataset, DatasetDict |
|
|
from transformers import PreTrainedTokenizer |
|
|
|
|
|
from config import DataConfig, ModelConfig, ConfigManager |
|
|
from parser import VCFParser, MutationRecord |
|
|
from tokenizer import HierarchicalVCFTokenizer, HierarchicalDataCollator |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class HierarchicalVCFDataset(Dataset): |
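    """
    Map-style PyTorch Dataset for hierarchical VCF mutation data.

    Samples are organized as pathway -> chromosome -> gene -> mutations and
    are encoded with a HierarchicalVCFTokenizer at construction time.

    Example (a minimal sketch; assumes a tokenizer with a built vocabulary):

        dataset = HierarchicalVCFDataset("cohort.json", tokenizer=tokenizer)
        first_sample = dataset[0]
    """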
|
|
|
|
|
def __init__(self, |
|
|
data_source: Union[str, Path, Dict, List], |
|
|
tokenizer: HierarchicalVCFTokenizer, |
|
|
config: Optional[DataConfig] = None, |
|
|
labels: Optional[Union[List, np.ndarray]] = None, |
|
|
transform: Optional[Callable] = None, |
|
|
target_transform: Optional[Callable] = None, |
|
|
cache_processed_data: bool = True): |
|
|
""" |
|
|
Initialize the Hierarchical VCF Dataset. |
|
|
Args: |
|
|
data_source: Path to data file, or preprocessed data dict/list |
|
|
tokenizer: Tokenizer for encoding mutations |
|
|
config: Data configuration |
|
|
labels: Optional labels for supervised learning |
|
|
transform: Optional transform to apply to samples |
|
|
target_transform: Optional transform to apply to labels |
|
|
cache_processed_data: Whether to cache processed data |
|
|
""" |
|
|
|
|
|
self.config = config or DataConfig() |
|
|
self.tokenizer = tokenizer |
|
|
self.labels = labels |
|
|
self.transform = transform |
|
|
self.target_transform = target_transform |
|
|
self.cache_processed_data = cache_processed_data |
|
|
|
|
|
|
|
|
self.raw_data = self._load_data(data_source) |
|
|
self.processed_data = self._process_data() |
|
|
|
|
|
|
|
|
self._validate_data() |
|
|
|
|
|
|
|
|
self.stats = self._compute_statistics() |
|
|
|
|
|
logger.info(f"Dataset initialized with {len(self.processed_data)} samples") |
|
|
logger.info(f"Dataset statistics: {self.stats}") |
|
|
|
|
|
def _load_data(self, data_source: Union[str, Path, Dict, List]) -> Dict[str, Any]: |
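        """Load raw data from a JSON/pickle/VCF file path or an in-memory dict/list."""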
|
|
|
|
|
if isinstance(data_source, (dict, list)): |
|
|
|
|
|
if isinstance(data_source, list): |
|
|
|
|
|
return {f"sample_{i}": sample for i, sample in enumerate(data_source)} |
|
|
return data_source |
|
|
|
|
|
|
|
|
data_path = Path(data_source) |
|
|
|
|
|
if not data_path.exists(): |
|
|
raise FileNotFoundError(f"Data file not found: {data_path}") |
|
|
|
|
|
try: |
|
|
if data_path.suffix.lower() == '.json': |
|
|
with open(data_path, 'r') as f: |
|
|
return json.load(f) |
|
|
|
|
|
elif data_path.suffix.lower() == '.pkl': |
|
|
with open(data_path, 'rb') as f: |
|
|
return pickle.load(f) |
|
|
|
|
|
elif data_path.suffix.lower() == '.vcf': |
|
|
|
|
|
parser = VCFParser(config=self.config) |
|
|
return parser.parse_vcf_file(data_path) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unsupported file format: {data_path.suffix}") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error loading data from {data_path}: {e}") |
|
|
raise |
|
|
|
|
|
def _process_data(self) -> List[Dict[str, Any]]: |
|
|
"""Raw hierarchical data into dataset format.""" |
|
|
|
|
|
processed_samples = [] |
|
|
|
|
|
for sample_id, sample_data in self.raw_data.items(): |
|
|
try: |
|
|
|
|
|
standardized_sample = self._standardize_sample_format(sample_data) |
|
|
|
|
|
|
|
|
if self._should_include_sample(standardized_sample): |
|
|
|
|
|
encoded_sample = self.tokenizer.encode_hierarchical_sample(standardized_sample) |
|
|
|
|
|
processed_sample = { |
|
|
'sample_id': sample_id, |
|
|
'encoded_data': encoded_sample, |
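                        # The raw sample is retained only when processed data is not cached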
|
|
'raw_data': standardized_sample if not self.cache_processed_data else None |
|
|
} |
|
|
|
|
|
processed_samples.append(processed_sample) |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Error processing sample {sample_id}: {e}") |
|
|
continue |
|
|
|
|
|
return processed_samples |
|
|
|
|
|
def _standardize_sample_format(self, sample_data: Dict[str, Any]) -> Dict[str, Any]: |
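        """Normalize a sample into pathway -> chromosome -> gene -> mutations form."""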
|
|
|
|
|
|
|
|
if 'mutations' in sample_data: |
|
|
|
|
|
return self._convert_flat_to_hierarchical(sample_data['mutations']) |
|
|
|
|
|
elif isinstance(sample_data, dict) and all( |
|
|
isinstance(v, dict) for v in sample_data.values() |
|
|
): |
|
|
|
|
|
return sample_data |
|
|
|
|
|
else: |
|
|
|
|
|
return self._convert_flat_to_hierarchical(sample_data) |
|
|
|
|
|
def _convert_flat_to_hierarchical(self, mutations: List[Dict]) -> Dict[str, Any]: |
|
|
"""Convert flat mutation list to hierarchical format.""" |
|
|
|
|
|
hierarchical = {} |
|
|
|
|
|
for mutation in mutations: |
|
|
|
|
|
pathway = mutation.get('pathway', 'Unknown_Pathway') |
|
|
chromosome = mutation.get('chromosome', mutation.get('chrom', 'Unknown')) |
|
|
gene = mutation.get('gene', mutation.get('gene_id', 'Unknown_Gene')) |
|
|
|
|
|
|
|
|
if pathway not in hierarchical: |
|
|
hierarchical[pathway] = {} |
|
|
if chromosome not in hierarchical[pathway]: |
|
|
hierarchical[pathway][chromosome] = {} |
|
|
if gene not in hierarchical[pathway][chromosome]: |
|
|
hierarchical[pathway][chromosome][gene] = [] |
|
|
|
|
|
|
|
|
hierarchical[pathway][chromosome][gene].append(mutation) |
|
|
|
|
|
return hierarchical |
|
|
|
|
|
def _should_include_sample(self, sample_data: Dict[str, Any]) -> bool: |
|
|
"""Determine if sample should be included based on filtering criteria.""" |
|
|
|
|
|
|
|
|
total_mutations = 0 |
|
|
for pathway_data in sample_data.values(): |
|
|
for chrom_data in pathway_data.values(): |
|
|
for gene_mutations in chrom_data.values(): |
|
|
total_mutations += len(gene_mutations) |
|
|
|
|
|
|
|
|
if total_mutations < self.config.min_mutations_per_sample: |
|
|
return False |
|
|
|
|
|
if total_mutations > self.config.max_mutations_per_sample: |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
def _validate_data(self) -> None: |
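        """Ensure the dataset is non-empty and that labels, if given, align with samples."""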
|
|
|
|
|
if len(self.processed_data) == 0: |
|
|
raise ValueError("No valid samples found in dataset") |
|
|
|
|
|
if self.labels is not None: |
|
|
if len(self.labels) != len(self.processed_data): |
|
|
raise ValueError( |
|
|
f"Number of labels ({len(self.labels)}) doesn't match " |
|
|
f"number of samples ({len(self.processed_data)})" |
|
|
) |
|
|
|
|
|
def _compute_statistics(self) -> Dict[str, Any]: |
|
|
"""CDataset statistics.""" |
|
|
|
|
|
stats = { |
|
|
'num_samples': len(self.processed_data), |
|
|
'num_pathways': set(), |
|
|
'num_chromosomes': set(), |
|
|
'num_genes': set(), |
|
|
'mutations_per_sample': [], |
|
|
'genes_per_sample': [], |
|
|
'pathways_per_sample': [] |
|
|
} |
|
|
|
|
|
for sample in self.processed_data: |
|
|
encoded_data = sample['encoded_data'] |
|
|
|
|
|
sample_pathways = len(encoded_data) |
|
|
sample_genes = 0 |
|
|
sample_mutations = 0 |
|
|
|
|
|
for pathway_token, chromosomes in encoded_data.items(): |
|
|
stats['num_pathways'].add(pathway_token) |
|
|
|
|
|
for chrom_token, genes in chromosomes.items(): |
|
|
stats['num_chromosomes'].add(chrom_token) |
|
|
|
|
|
for gene_token, mutations in genes.items(): |
|
|
stats['num_genes'].add(gene_token) |
|
|
sample_genes += 1 |
|
|
|
|
|
|
|
|
if 'impact' in mutations: |
|
|
sample_mutations += len(mutations['impact']) |
|
|
|
|
|
stats['mutations_per_sample'].append(sample_mutations) |
|
|
stats['genes_per_sample'].append(sample_genes) |
|
|
stats['pathways_per_sample'].append(sample_pathways) |
|
|
|
|
|
|
|
|
stats['unique_pathways'] = len(stats['num_pathways']) |
|
|
stats['unique_chromosomes'] = len(stats['num_chromosomes']) |
|
|
stats['unique_genes'] = len(stats['num_genes']) |
|
|
|
|
|
|
|
|
if stats['mutations_per_sample']: |
|
|
stats['avg_mutations_per_sample'] = np.mean(stats['mutations_per_sample']) |
|
|
stats['std_mutations_per_sample'] = np.std(stats['mutations_per_sample']) |
|
|
|
|
|
if stats['genes_per_sample']: |
|
|
stats['avg_genes_per_sample'] = np.mean(stats['genes_per_sample']) |
|
|
stats['std_genes_per_sample'] = np.std(stats['genes_per_sample']) |
|
|
|
|
|
|
|
|
del stats['num_pathways'], stats['num_chromosomes'], stats['num_genes'] |
|
|
|
|
|
return stats |
|
|
|
|
|
def __len__(self) -> int: |
|
|
"""Number of samples in the dataset.""" |
|
|
return len(self.processed_data) |
|
|
|
|
|
def __getitem__(self, idx: int) -> Dict[str, Any]: |
|
|
"""Single sample from the dataset.""" |
|
|
|
|
|
if idx >= len(self.processed_data): |
|
|
raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}") |
|
|
|
|
|
sample = self.processed_data[idx].copy() |
|
|
|
|
|
|
|
|
if self.transform: |
|
|
sample['encoded_data'] = self.transform(sample['encoded_data']) |
|
|
|
|
|
|
|
|
if self.labels is not None: |
|
|
label = self.labels[idx] |
|
|
if self.target_transform: |
|
|
label = self.target_transform(label) |
|
|
sample['label'] = label |
|
|
|
|
|
return sample |
|
|
|
|
|
def get_sample_by_id(self, sample_id: str) -> Optional[Dict[str, Any]]: |
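        """Look up a sample by its string ID, returning None if no match is found."""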
|
|
for i, sample in enumerate(self.processed_data): |
|
|
if sample['sample_id'] == sample_id: |
|
|
return self.__getitem__(i) |
|
|
return None |
|
|
|
|
|
def get_statistics(self) -> Dict[str, Any]: |
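        """Return a copy of the precomputed dataset statistics."""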
|
|
return self.stats.copy() |
|
|
|
|
|
def save_dataset(self, save_path: Union[str, Path], format: str = 'pickle') -> None: |
|
|
""" |
|
|
Args: |
|
|
save_path: Path to save the dataset |
|
|
format: Save format ('pickle', 'json') |
|
|
""" |
|
|
save_path = Path(save_path) |
|
|
save_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
dataset_info = { |
|
|
'processed_data': self.processed_data, |
|
|
'labels': self.labels.tolist() if isinstance(self.labels, np.ndarray) else self.labels, |
|
|
'stats': self.stats, |
|
|
'config': self.config.__dict__ if hasattr(self.config, '__dict__') else None |
|
|
} |
|
|
|
|
|
if format.lower() == 'pickle': |
|
|
with open(save_path, 'wb') as f: |
|
|
pickle.dump(dataset_info, f) |
|
|
|
|
|
elif format.lower() == 'json': |
|
|
with open(save_path, 'w') as f: |
|
|
json.dump(dataset_info, f, indent=2, default=str) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unsupported save format: {format}") |
|
|
|
|
|
logger.info(f"Dataset saved to {save_path}") |
|
|
|
|
|
@classmethod |
|
|
def load_dataset(cls, |
|
|
load_path: Union[str, Path], |
|
|
tokenizer: HierarchicalVCFTokenizer, |
|
|
format: str = 'auto') -> 'HierarchicalVCFDataset': |
|
|
""" |
|
|
Args: |
|
|
load_path: Path to load the dataset from |
|
|
tokenizer: Tokenizer instance |
|
|
format: Load format ('pickle', 'json', 'auto') |
|
|
|
|
|
Returns: |
|
|
Loaded dataset instance |
|
|
""" |
|
|
load_path = Path(load_path) |
|
|
|
|
|
if not load_path.exists(): |
|
|
raise FileNotFoundError(f"Dataset file not found: {load_path}") |
|
|
|
|
|
|
|
|
if format == 'auto': |
|
|
            format = 'pickle' if load_path.suffix.lower() == '.pkl' else 'json'
|
|
|
|
|
|
|
|
if format.lower() == 'pickle': |
|
|
with open(load_path, 'rb') as f: |
|
|
dataset_info = pickle.load(f) |
|
|
|
|
|
elif format.lower() == 'json': |
|
|
with open(load_path, 'r') as f: |
|
|
dataset_info = json.load(f) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unsupported load format: {format}") |
|
|
|
|
|
|
|
|
dataset = cls.__new__(cls) |
|
|
dataset.tokenizer = tokenizer |
|
|
dataset.processed_data = dataset_info['processed_data'] |
|
|
dataset.labels = dataset_info.get('labels') |
|
|
dataset.stats = dataset_info.get('stats', {}) |
|
|
dataset.config = dataset_info.get('config', DataConfig()) |
|
|
dataset.transform = None |
|
|
dataset.target_transform = None |
|
|
dataset.cache_processed_data = True |
|
|
|
|
|
return dataset |
|
|
|
|
|
|
|
|
class HierarchicalVCFDataModule: |
|
|
""" |
|
|
Manage train/validation/test splits of hierarchical VCF data. |
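
    Example (a minimal sketch; assumes a tokenizer with a built vocabulary):

        dm = HierarchicalVCFDataModule("cohort.json", tokenizer=tokenizer)
        train_loader, val_loader, test_loader = dm.get_dataloaders(batch_size=8)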
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
data_source: Union[str, Path, Dict], |
|
|
tokenizer: HierarchicalVCFTokenizer, |
|
|
config: Optional[DataConfig] = None, |
|
|
labels: Optional[Union[List, np.ndarray]] = None, |
|
|
train_split: float = 0.8, |
|
|
val_split: float = 0.1, |
|
|
test_split: float = 0.1, |
|
|
stratify: bool = True, |
|
|
random_seed: int = 42): |
|
|
""" |
|
|
Args: |
|
|
data_source: Source of the data |
|
|
tokenizer: Tokenizer for encoding |
|
|
config: Data configuration |
|
|
labels: Labels for supervised learning |
|
|
train_split: Proportion for training |
|
|
val_split: Proportion for validation |
|
|
test_split: Proportion for testing |
|
|
stratify: Whether to stratify splits by labels |
|
|
random_seed: Random seed for reproducibility |
|
|
""" |
|
|
|
|
|
self.config = config or DataConfig() |
|
|
self.tokenizer = tokenizer |
|
|
self.train_split = train_split |
|
|
self.val_split = val_split |
|
|
self.test_split = test_split |
|
|
self.stratify = stratify |
|
|
self.random_seed = random_seed |
|
|
|
|
|
|
|
|
if abs(train_split + val_split + test_split - 1.0) > 1e-6: |
|
|
raise ValueError("Train, validation, and test splits must sum to 1.0") |
|
|
|
|
|
|
|
|
self.full_dataset = HierarchicalVCFDataset( |
|
|
data_source=data_source, |
|
|
tokenizer=tokenizer, |
|
|
config=config, |
|
|
labels=labels |
|
|
) |
|
|
|
|
|
|
|
|
self.train_dataset, self.val_dataset, self.test_dataset = self._create_splits() |
|
|
|
|
|
logger.info(f"Data module initialized:") |
|
|
logger.info(f" Train: {len(self.train_dataset)} samples") |
|
|
logger.info(f" Validation: {len(self.val_dataset)} samples") |
|
|
logger.info(f" Test: {len(self.test_dataset)} samples") |
|
|
|
|
|
def _create_splits(self) -> Tuple[Dataset, Dataset, Dataset]: |
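        """Split sample indices into train/val/test, stratifying by labels when requested."""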
|
|
|
|
|
np.random.seed(self.random_seed) |
|
|
|
|
|
indices = np.arange(len(self.full_dataset)) |
|
|
|
|
|
if self.stratify and self.full_dataset.labels is not None: |
|
|
|
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
|
|
|
train_idx, temp_idx = train_test_split( |
|
|
indices, |
|
|
test_size=(self.val_split + self.test_split), |
|
|
stratify=[self.full_dataset.labels[i] for i in indices], |
|
|
random_state=self.random_seed |
|
|
) |
|
|
|
|
|
|
|
|
if self.test_split > 0: |
|
|
val_idx, test_idx = train_test_split( |
|
|
temp_idx, |
|
|
test_size=self.test_split / (self.val_split + self.test_split), |
|
|
stratify=[self.full_dataset.labels[i] for i in temp_idx], |
|
|
random_state=self.random_seed |
|
|
) |
|
|
else: |
|
|
val_idx = temp_idx |
|
|
test_idx = np.array([]) |
|
|
|
|
|
else: |
|
|
|
|
|
np.random.shuffle(indices) |
|
|
|
|
|
train_end = int(self.train_split * len(indices)) |
|
|
val_end = int((self.train_split + self.val_split) * len(indices)) |
|
|
|
|
|
train_idx = indices[:train_end] |
|
|
val_idx = indices[train_end:val_end] |
|
|
test_idx = indices[val_end:] |
|
|
|
|
|
|
|
|
train_dataset = self._create_subset(train_idx) |
|
|
val_dataset = self._create_subset(val_idx) |
|
|
test_dataset = self._create_subset(test_idx) |
|
|
|
|
|
return train_dataset, val_dataset, test_dataset |
|
|
|
|
|
def _create_subset(self, indices: np.ndarray) -> Dataset: |
|
|
"""Create a subset dataset from indices.""" |
|
|
|
|
|
subset_data = [self.full_dataset.processed_data[i] for i in indices] |
|
|
subset_labels = None |
|
|
|
|
|
if self.full_dataset.labels is not None: |
|
|
if isinstance(self.full_dataset.labels, np.ndarray): |
|
|
subset_labels = self.full_dataset.labels[indices] |
|
|
else: |
|
|
subset_labels = [self.full_dataset.labels[i] for i in indices] |
|
|
|
|
|
|
|
|
dataset = HierarchicalVCFDataset.__new__(HierarchicalVCFDataset) |
|
|
dataset.tokenizer = self.tokenizer |
|
|
dataset.config = self.config |
|
|
dataset.processed_data = subset_data |
|
|
dataset.labels = subset_labels |
|
|
dataset.transform = None |
|
|
dataset.target_transform = None |
|
|
dataset.cache_processed_data = True |
|
|
dataset.stats = dataset._compute_statistics() |
|
|
|
|
|
return dataset |
|
|
|
|
|
def get_dataloaders(self, |
|
|
batch_size: int = 16, |
|
|
num_workers: int = 0, |
|
|
collate_fn: Optional[Callable] = None) -> Tuple[DataLoader, DataLoader, DataLoader]: |
|
|
""" |
|
|
Args: |
|
|
batch_size: Batch size for data loading |
|
|
num_workers: Number of worker processes |
|
|
collate_fn: Custom collate function |
|
|
|
|
|
Returns: |
|
|
Tuple of (train_loader, val_loader, test_loader) |
|
|
""" |
|
|
|
|
|
if collate_fn is None: |
|
|
collate_fn = HierarchicalDataCollator(self.tokenizer) |
|
|
|
|
|
train_loader = DataLoader( |
|
|
self.train_dataset, |
|
|
batch_size=batch_size, |
|
|
shuffle=True, |
|
|
num_workers=num_workers, |
|
|
collate_fn=collate_fn |
|
|
) |
|
|
|
|
|
val_loader = DataLoader( |
|
|
self.val_dataset, |
|
|
batch_size=batch_size, |
|
|
shuffle=False, |
|
|
num_workers=num_workers, |
|
|
collate_fn=collate_fn |
|
|
) |
|
|
|
|
|
test_loader = DataLoader( |
|
|
self.test_dataset, |
|
|
batch_size=batch_size, |
|
|
shuffle=False, |
|
|
num_workers=num_workers, |
|
|
collate_fn=collate_fn |
|
|
) |
|
|
|
|
|
return train_loader, val_loader, test_loader |
|
|
|
|
|
|
|
|
class HuggingFaceDatasetAdapter: |
|
|
""" |
|
|
Convert hierarchical VCF data to Hugging Face Dataset format. |
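
    Example:

        adapter = HuggingFaceDatasetAdapter(vcf_dataset)
        hf_dict = adapter.to_huggingface_dataset()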
|
|
""" |
|
|
|
|
|
def __init__(self, vcf_dataset: HierarchicalVCFDataset): |
|
|
self.vcf_dataset = vcf_dataset |
|
|
|
|
|
def to_huggingface_dataset(self) -> DatasetDict: |
|
|
""" |
|
|
Returns: |
|
|
HuggingFace DatasetDict |
|
|
""" |
|
|
|
|
|
|
|
|
flattened_data = [] |
|
|
|
|
|
for sample in self.vcf_dataset.processed_data: |
|
|
sample_id = sample['sample_id'] |
|
|
encoded_data = sample['encoded_data'] |
|
|
|
|
|
|
|
|
flattened_sample = { |
|
|
'sample_id': sample_id, |
|
|
'pathways': list(encoded_data.keys()), |
|
|
'num_pathways': len(encoded_data), |
|
|
'encoded_mutations': self._flatten_mutations(encoded_data) |
|
|
} |
|
|
|
|
|
flattened_data.append(flattened_sample) |
|
|
|
|
|
|
|
|
if self.vcf_dataset.labels is not None: |
|
|
for i, sample in enumerate(flattened_data): |
|
|
sample['label'] = self.vcf_dataset.labels[i] |
|
|
|
|
|
|
|
|
hf_dataset = HFDataset.from_list(flattened_data) |
|
|
|
|
|
return DatasetDict({'train': hf_dataset}) |
|
|
|
|
|
def _flatten_mutations(self, encoded_data: Dict) -> Dict[str, List]: |
|
|
"""Flatten hierarchical mutations for HF compatibility.""" |
|
|
|
|
|
all_impacts = [] |
|
|
all_refs = [] |
|
|
all_alts = [] |
|
|
|
|
|
for pathway_token, chromosomes in encoded_data.items(): |
|
|
for chrom_token, genes in chromosomes.items(): |
|
|
for gene_token, mutations in genes.items(): |
|
|
if 'impact' in mutations: |
|
|
all_impacts.extend(mutations['impact']) |
|
|
if 'ref' in mutations: |
|
|
all_refs.extend(mutations['ref']) |
|
|
if 'alt' in mutations: |
|
|
all_alts.extend(mutations['alt']) |
|
|
|
|
|
return { |
|
|
'impacts': all_impacts, |
|
|
'refs': all_refs, |
|
|
'alts': all_alts |
|
|
} |
|
|
|
|
|
|
|
|
def create_dataset_from_config(config_manager: ConfigManager, |
|
|
tokenizer: HierarchicalVCFTokenizer, |
|
|
labels: Optional[List] = None) -> HierarchicalVCFDataset: |
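    """Build a HierarchicalVCFDataset from the VCF path in a ConfigManager."""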
|
|
|
|
|
data_config = config_manager.data_config |
|
|
|
|
|
if not data_config.vcf_file_path: |
|
|
raise ValueError("VCF file path not specified in configuration") |
|
|
|
|
|
return HierarchicalVCFDataset( |
|
|
data_source=data_config.vcf_file_path, |
|
|
tokenizer=tokenizer, |
|
|
config=data_config, |
|
|
labels=labels |
|
|
) |
|
|
|
|
|
|
|
|
def create_data_module_from_config(config_manager: ConfigManager, |
|
|
tokenizer: HierarchicalVCFTokenizer, |
|
|
labels: Optional[List] = None) -> HierarchicalVCFDataModule: |
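    """Build a HierarchicalVCFDataModule from the VCF path in a ConfigManager."""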
|
|
|
|
|
data_config = config_manager.data_config |
|
|
|
|
|
if not data_config.vcf_file_path: |
|
|
raise ValueError("VCF file path not specified in configuration") |
|
|
|
|
|
return HierarchicalVCFDataModule( |
|
|
data_source=data_config.vcf_file_path, |
|
|
tokenizer=tokenizer, |
|
|
config=data_config, |
|
|
labels=labels |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def create_synthetic_labels(dataset: HierarchicalVCFDataset, |
|
|
label_type: str = 'random', |
|
|
num_classes: int = 2) -> np.ndarray: |
|
|
""" |
|
|
Create synthetic labels for testing purposes. |
|
|
|
|
|
Args: |
|
|
dataset: VCF dataset |
|
|
label_type: Type of labels ('random', 'mutation_count_based') |
|
|
num_classes: Number of classes for classification |
|
|
|
|
|
Returns: |
|
|
Array of synthetic labels |
|
|
""" |
|
|
|
|
|
num_samples = len(dataset) |
|
|
|
|
|
if label_type == 'random': |
|
|
return np.random.randint(0, num_classes, size=num_samples) |
|
|
|
|
|
elif label_type == 'mutation_count_based': |
|
|
|
|
|
mutation_counts = dataset.stats['mutations_per_sample'] |
|
|
threshold = np.median(mutation_counts) |
|
|
|
|
|
        labels = []

        if num_classes == 2:
            for count in mutation_counts:
                labels.append(1 if count > threshold else 0)
        else:
            # Compute class boundaries once from mutation-count percentiles
            percentiles = np.linspace(0, 100, num_classes + 1)
            thresholds = np.percentile(mutation_counts, percentiles[1:-1])

            for count in mutation_counts:
                label = 0
                for i, t in enumerate(thresholds):
                    if count > t:
                        label = i + 1
                    else:
                        break
                labels.append(label)

        return np.array(labels)
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown label type: {label_type}") |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
from tokenizer import create_tokenizer_from_config |
|
|
|
|
|
|
|
|
config_manager = ConfigManager() |
|
|
config_manager.data_config.vcf_file_path = "example_data.json" |
|
|
|
|
|
|
|
|
tokenizer = create_tokenizer_from_config(config_manager) |
|
|
|
|
|
|
|
|
example_data = { |
|
|
'sample1': { |
|
|
'pathway1': { |
|
|
'chr1': { |
|
|
'gene1': [ |
|
|
{'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'}, |
|
|
{'impact': 'MODERATE', 'reference': 'G', 'alternate': 'C'} |
|
|
] |
|
|
} |
|
|
} |
|
|
}, |
|
|
'sample2': { |
|
|
'pathway2': { |
|
|
'chr2': { |
|
|
'gene2': [ |
|
|
{'impact': 'LOW', 'reference': 'T', 'alternate': 'A'} |
|
|
] |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
tokenizer.build_vocabulary(example_data) |
|
|
|
|
|
|
|
|
dataset = HierarchicalVCFDataset( |
|
|
data_source=example_data, |
|
|
tokenizer=tokenizer |
|
|
) |
|
|
|
|
|
|
|
|
labels = create_synthetic_labels(dataset, label_type='random', num_classes=2) |
|
|
dataset.labels = labels |
|
|
|
|
|
|
|
|
data_module = HierarchicalVCFDataModule( |
|
|
data_source=example_data, |
|
|
tokenizer=tokenizer, |
|
|
labels=labels, |
|
|
train_split=0.6, |
|
|
val_split=0.2, |
|
|
test_split=0.2 |
|
|
) |
|
|
|
|
|
|
|
|
train_loader, val_loader, test_loader = data_module.get_dataloaders(batch_size=2) |
|
|
|
|
|
|
|
|
for batch in train_loader: |
|
|
print(f"Batch size: {batch['batch_size']}") |
|
|
print(f"Sample IDs: {[s.get('sample_id', 'N/A') for s in batch['samples']]}") |
|
|
break |
|
|
|
|
|
print(f"Dataset statistics: {dataset.get_statistics()}") |