abd-ur committed on
Commit c278a8f · verified · 1 Parent(s): 3dec9cb

Create dataset.py

Files changed (1)
  1. dataset.py +763 -0
dataset.py ADDED
@@ -0,0 +1,763 @@
"""
This module provides PyTorch Dataset implementations for hierarchical VCF data.
"""

import torch
import json
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union, Any, Callable
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd

from datasets import Dataset as HFDataset, DatasetDict
from transformers import PreTrainedTokenizer

from config import DataConfig, ModelConfig, ConfigManager
from parser import VCFParser, MutationRecord
from tokenizer import HierarchicalVCFTokenizer, HierarchicalDataCollator


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class HierarchicalVCFDataset(Dataset):
    """PyTorch Dataset for hierarchical VCF mutation data."""

    def __init__(self,
                 data_source: Union[str, Path, Dict, List],
                 tokenizer: HierarchicalVCFTokenizer,
                 config: Optional[DataConfig] = None,
                 labels: Optional[Union[List, np.ndarray]] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 cache_processed_data: bool = True):
        """
        Initialize the Hierarchical VCF Dataset.

        Args:
            data_source: Path to a data file, or preprocessed data dict/list
            tokenizer: Tokenizer for encoding mutations
            config: Data configuration
            labels: Optional labels for supervised learning
            transform: Optional transform to apply to samples
            target_transform: Optional transform to apply to labels
            cache_processed_data: Whether to cache processed data
        """
        self.config = config or DataConfig()
        self.tokenizer = tokenizer
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.cache_processed_data = cache_processed_data

        # Load and process data
        self.raw_data = self._load_data(data_source)
        self.processed_data = self._process_data()

        # Validate data consistency
        self._validate_data()

        # Dataset statistics
        self.stats = self._compute_statistics()

        logger.info(f"Dataset initialized with {len(self.processed_data)} samples")
        logger.info(f"Dataset statistics: {self.stats}")

    def _load_data(self, data_source: Union[str, Path, Dict, List]) -> Dict[str, Any]:
        """Load raw data from a file path or an in-memory dict/list."""
        if isinstance(data_source, (dict, list)):
            # Data already loaded
            if isinstance(data_source, list):
                # Convert list to dict format
                return {f"sample_{i}": sample for i, sample in enumerate(data_source)}
            return data_source

        # Load from file
        data_path = Path(data_source)

        if not data_path.exists():
            raise FileNotFoundError(f"Data file not found: {data_path}")

        try:
            if data_path.suffix.lower() == '.json':
                with open(data_path, 'r') as f:
                    return json.load(f)

            elif data_path.suffix.lower() == '.pkl':
                with open(data_path, 'rb') as f:
                    return pickle.load(f)

            elif data_path.suffix.lower() == '.vcf':
                # Parse VCF file directly
                parser = VCFParser(config=self.config)
                return parser.parse_vcf_file(data_path)

            else:
                raise ValueError(f"Unsupported file format: {data_path.suffix}")

        except Exception as e:
            logger.error(f"Error loading data from {data_path}: {e}")
            raise

    def _process_data(self) -> List[Dict[str, Any]]:
        """Process raw hierarchical data into dataset format."""
        processed_samples = []

        for sample_id, sample_data in self.raw_data.items():
            try:
                # Convert to standard format if needed
                standardized_sample = self._standardize_sample_format(sample_data)

                # Filter samples based on configuration
                if self._should_include_sample(standardized_sample):
                    # Encode the sample
                    encoded_sample = self.tokenizer.encode_hierarchical_sample(standardized_sample)

                    processed_sample = {
                        'sample_id': sample_id,
                        'encoded_data': encoded_sample,
                        # Keep the raw sample only when processed-data caching is disabled
                        'raw_data': standardized_sample if not self.cache_processed_data else None
                    }

                    processed_samples.append(processed_sample)

            except Exception as e:
                logger.warning(f"Error processing sample {sample_id}: {e}")
                continue

        return processed_samples

    def _standardize_sample_format(self, sample_data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a sample to the nested pathway/chromosome/gene format."""
        # Handle different input formats
        if 'mutations' in sample_data:
            # Format: {'mutations': [...]}
            return self._convert_flat_to_hierarchical(sample_data['mutations'])

        elif isinstance(sample_data, dict) and all(
            isinstance(v, dict) for v in sample_data.values()
        ):
            # Already in hierarchical format
            return sample_data

        else:
            # Assume it's a list of mutations
            return self._convert_flat_to_hierarchical(sample_data)

    def _convert_flat_to_hierarchical(self, mutations: List[Dict]) -> Dict[str, Any]:
        """Convert flat mutation list to hierarchical format."""
        hierarchical = {}

        for mutation in mutations:
            # Extract hierarchical keys
            pathway = mutation.get('pathway', 'Unknown_Pathway')
            chromosome = mutation.get('chromosome', mutation.get('chrom', 'Unknown'))
            gene = mutation.get('gene', mutation.get('gene_id', 'Unknown_Gene'))

            # Initialize nested structure
            if pathway not in hierarchical:
                hierarchical[pathway] = {}
            if chromosome not in hierarchical[pathway]:
                hierarchical[pathway][chromosome] = {}
            if gene not in hierarchical[pathway][chromosome]:
                hierarchical[pathway][chromosome][gene] = []

            # Add mutation
            hierarchical[pathway][chromosome][gene].append(mutation)

        return hierarchical
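
    # A minimal sketch of the conversion above, with hypothetical field values
    # ('P53_Signaling', 'chr17', and 'TP53' are illustrative, not taken from
    # any real sample):
    #
    #   flat = [{'pathway': 'P53_Signaling', 'chromosome': 'chr17',
    #            'gene': 'TP53', 'impact': 'HIGH'}]
    #   self._convert_flat_to_hierarchical(flat)
    #   # -> {'P53_Signaling': {'chr17': {'TP53': [<the mutation dict>]}}}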

    def _should_include_sample(self, sample_data: Dict[str, Any]) -> bool:
        """Determine if sample should be included based on filtering criteria."""
        # Count total mutations
        total_mutations = 0
        for pathway_data in sample_data.values():
            for chrom_data in pathway_data.values():
                for gene_mutations in chrom_data.values():
                    total_mutations += len(gene_mutations)

        # Apply filters
        if total_mutations < self.config.min_mutations_per_sample:
            return False

        if total_mutations > self.config.max_mutations_per_sample:
            return False

        return True

    def _validate_data(self) -> None:
        """Validate consistency between samples and labels."""
        if len(self.processed_data) == 0:
            raise ValueError("No valid samples found in dataset")

        if self.labels is not None:
            if len(self.labels) != len(self.processed_data):
                raise ValueError(
                    f"Number of labels ({len(self.labels)}) doesn't match "
                    f"number of samples ({len(self.processed_data)})"
                )

    def _compute_statistics(self) -> Dict[str, Any]:
        """Compute dataset statistics."""
        stats = {
            'num_samples': len(self.processed_data),
            'num_pathways': set(),
            'num_chromosomes': set(),
            'num_genes': set(),
            'mutations_per_sample': [],
            'genes_per_sample': [],
            'pathways_per_sample': []
        }

        for sample in self.processed_data:
            encoded_data = sample['encoded_data']

            sample_pathways = len(encoded_data)
            sample_genes = 0
            sample_mutations = 0

            for pathway_token, chromosomes in encoded_data.items():
                stats['num_pathways'].add(pathway_token)

                for chrom_token, genes in chromosomes.items():
                    stats['num_chromosomes'].add(chrom_token)

                    for gene_token, mutations in genes.items():
                        stats['num_genes'].add(gene_token)
                        sample_genes += 1

                        # Count mutations (assuming 'impact' field exists)
                        if 'impact' in mutations:
                            sample_mutations += len(mutations['impact'])

            stats['mutations_per_sample'].append(sample_mutations)
            stats['genes_per_sample'].append(sample_genes)
            stats['pathways_per_sample'].append(sample_pathways)

        # Convert sets to counts
        stats['unique_pathways'] = len(stats['num_pathways'])
        stats['unique_chromosomes'] = len(stats['num_chromosomes'])
        stats['unique_genes'] = len(stats['num_genes'])

        # Compute summary statistics
        if stats['mutations_per_sample']:
            stats['avg_mutations_per_sample'] = np.mean(stats['mutations_per_sample'])
            stats['std_mutations_per_sample'] = np.std(stats['mutations_per_sample'])

        if stats['genes_per_sample']:
            stats['avg_genes_per_sample'] = np.mean(stats['genes_per_sample'])
            stats['std_genes_per_sample'] = np.std(stats['genes_per_sample'])

        # Remove raw sets
        del stats['num_pathways'], stats['num_chromosomes'], stats['num_genes']

        return stats

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.processed_data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return a single sample from the dataset."""
        if idx >= len(self.processed_data):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")

        sample = self.processed_data[idx].copy()

        # Apply transforms
        if self.transform:
            sample['encoded_data'] = self.transform(sample['encoded_data'])

        # Add label if available
        if self.labels is not None:
            label = self.labels[idx]
            if self.target_transform:
                label = self.target_transform(label)
            sample['label'] = label

        return sample

    def get_sample_by_id(self, sample_id: str) -> Optional[Dict[str, Any]]:
        """Look up a sample by its ID, or return None if not found."""
        for i, sample in enumerate(self.processed_data):
            if sample['sample_id'] == sample_id:
                return self.__getitem__(i)
        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Return a copy of the dataset statistics."""
        return self.stats.copy()
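
    # A sketch of the dict returned by get_statistics(); keys follow
    # _compute_statistics above, values are illustrative for a two-sample set:
    #
    #   {'num_samples': 2,
    #    'mutations_per_sample': [2, 1], 'genes_per_sample': [1, 1],
    #    'pathways_per_sample': [1, 1],
    #    'unique_pathways': 2, 'unique_chromosomes': 2, 'unique_genes': 2,
    #    'avg_mutations_per_sample': 1.5, 'std_mutations_per_sample': 0.5,
    #    'avg_genes_per_sample': 1.0, 'std_genes_per_sample': 0.0}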

    def save_dataset(self, save_path: Union[str, Path], format: str = 'pickle') -> None:
        """
        Save the processed dataset to disk.

        Args:
            save_path: Path to save the dataset
            format: Save format ('pickle', 'json')
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        dataset_info = {
            'processed_data': self.processed_data,
            'labels': self.labels.tolist() if isinstance(self.labels, np.ndarray) else self.labels,
            'stats': self.stats,
            'config': self.config.__dict__ if hasattr(self.config, '__dict__') else None
        }

        if format.lower() == 'pickle':
            with open(save_path, 'wb') as f:
                pickle.dump(dataset_info, f)

        elif format.lower() == 'json':
            with open(save_path, 'w') as f:
                json.dump(dataset_info, f, indent=2, default=str)

        else:
            raise ValueError(f"Unsupported save format: {format}")

        logger.info(f"Dataset saved to {save_path}")

    @classmethod
    def load_dataset(cls,
                     load_path: Union[str, Path],
                     tokenizer: HierarchicalVCFTokenizer,
                     format: str = 'auto') -> 'HierarchicalVCFDataset':
        """
        Load a previously saved dataset from disk.

        Args:
            load_path: Path to load the dataset from
            tokenizer: Tokenizer instance
            format: Load format ('pickle', 'json', 'auto')

        Returns:
            Loaded dataset instance
        """
        load_path = Path(load_path)

        if not load_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {load_path}")

        # Determine format
        if format == 'auto':
            format = 'pickle' if load_path.suffix == '.pkl' else 'json'

        # Load data
        if format.lower() == 'pickle':
            with open(load_path, 'rb') as f:
                dataset_info = pickle.load(f)

        elif format.lower() == 'json':
            with open(load_path, 'r') as f:
                dataset_info = json.load(f)

        else:
            raise ValueError(f"Unsupported load format: {format}")

        # Create dataset instance without re-running __init__
        dataset = cls.__new__(cls)
        dataset.tokenizer = tokenizer
        dataset.processed_data = dataset_info['processed_data']
        dataset.labels = dataset_info.get('labels')
        dataset.stats = dataset_info.get('stats', {})
        # Note: the config was saved as a plain attribute dict, not a DataConfig
        dataset.config = dataset_info.get('config', DataConfig())
        dataset.transform = None
        dataset.target_transform = None
        dataset.cache_processed_data = True

        return dataset
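
# A minimal save/load round trip, assuming an existing `dataset` and
# `tokenizer` ("dataset.pkl" is an illustrative path):
#
#   dataset.save_dataset("dataset.pkl", format='pickle')
#   restored = HierarchicalVCFDataset.load_dataset("dataset.pkl", tokenizer)
#   assert len(restored) == len(dataset)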


class HierarchicalVCFDataModule:
    """
    Manage train/validation/test splits of hierarchical VCF data.
    """

    def __init__(self,
                 data_source: Union[str, Path, Dict],
                 tokenizer: HierarchicalVCFTokenizer,
                 config: Optional[DataConfig] = None,
                 labels: Optional[Union[List, np.ndarray]] = None,
                 train_split: float = 0.8,
                 val_split: float = 0.1,
                 test_split: float = 0.1,
                 stratify: bool = True,
                 random_seed: int = 42):
        """
        Initialize the data module and build the train/val/test splits.

        Args:
            data_source: Source of the data
            tokenizer: Tokenizer for encoding
            config: Data configuration
            labels: Labels for supervised learning
            train_split: Proportion for training
            val_split: Proportion for validation
            test_split: Proportion for testing
            stratify: Whether to stratify splits by labels
            random_seed: Random seed for reproducibility
        """
        self.config = config or DataConfig()
        self.tokenizer = tokenizer
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.stratify = stratify
        self.random_seed = random_seed

        # Validate splits
        if abs(train_split + val_split + test_split - 1.0) > 1e-6:
            raise ValueError("Train, validation, and test splits must sum to 1.0")

        # Load full dataset
        self.full_dataset = HierarchicalVCFDataset(
            data_source=data_source,
            tokenizer=tokenizer,
            config=config,
            labels=labels
        )

        # Create splits
        self.train_dataset, self.val_dataset, self.test_dataset = self._create_splits()

        logger.info("Data module initialized:")
        logger.info(f"  Train: {len(self.train_dataset)} samples")
        logger.info(f"  Validation: {len(self.val_dataset)} samples")
        logger.info(f"  Test: {len(self.test_dataset)} samples")

    def _create_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Split the full dataset into train/validation/test subsets."""
        np.random.seed(self.random_seed)

        indices = np.arange(len(self.full_dataset))

        if self.stratify and self.full_dataset.labels is not None:
            # Stratified split
            from sklearn.model_selection import train_test_split

            # First split: train vs (val + test)
            train_idx, temp_idx = train_test_split(
                indices,
                test_size=(self.val_split + self.test_split),
                stratify=[self.full_dataset.labels[i] for i in indices],
                random_state=self.random_seed
            )

            # Second split: val vs test
            if self.test_split > 0:
                val_idx, test_idx = train_test_split(
                    temp_idx,
                    test_size=self.test_split / (self.val_split + self.test_split),
                    stratify=[self.full_dataset.labels[i] for i in temp_idx],
                    random_state=self.random_seed
                )
            else:
                val_idx = temp_idx
                test_idx = np.array([], dtype=int)  # integer dtype so indexing works

        else:
            # Random split
            np.random.shuffle(indices)

            train_end = int(self.train_split * len(indices))
            val_end = int((self.train_split + self.val_split) * len(indices))

            train_idx = indices[:train_end]
            val_idx = indices[train_end:val_end]
            test_idx = indices[val_end:]

        # Create subset datasets
        train_dataset = self._create_subset(train_idx)
        val_dataset = self._create_subset(val_idx)
        test_dataset = self._create_subset(test_idx)

        return train_dataset, val_dataset, test_dataset

    def _create_subset(self, indices: np.ndarray) -> Dataset:
        """Create a subset dataset from indices."""
        subset_data = [self.full_dataset.processed_data[i] for i in indices]
        subset_labels = None

        if self.full_dataset.labels is not None:
            if isinstance(self.full_dataset.labels, np.ndarray):
                subset_labels = self.full_dataset.labels[indices]
            else:
                subset_labels = [self.full_dataset.labels[i] for i in indices]

        # Create new dataset instance
        dataset = HierarchicalVCFDataset.__new__(HierarchicalVCFDataset)
        dataset.tokenizer = self.tokenizer
        dataset.config = self.config
        dataset.processed_data = subset_data
        dataset.labels = subset_labels
        dataset.transform = None
        dataset.target_transform = None
        dataset.cache_processed_data = True
        dataset.stats = dataset._compute_statistics()

        return dataset

    def get_dataloaders(self,
                        batch_size: int = 16,
                        num_workers: int = 0,
                        collate_fn: Optional[Callable] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Create PyTorch DataLoaders for the three splits.

        Args:
            batch_size: Batch size for data loading
            num_workers: Number of worker processes
            collate_fn: Custom collate function

        Returns:
            Tuple of (train_loader, val_loader, test_loader)
        """
        if collate_fn is None:
            collate_fn = HierarchicalDataCollator(self.tokenizer)

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn
        )

        test_loader = DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn
        )

        return train_loader, val_loader, test_loader


class HuggingFaceDatasetAdapter:
    """
    Convert hierarchical VCF data to Hugging Face Dataset format.
    """

    def __init__(self, vcf_dataset: HierarchicalVCFDataset):
        """Wrap an existing HierarchicalVCFDataset."""
        self.vcf_dataset = vcf_dataset

    def to_huggingface_dataset(self) -> DatasetDict:
        """
        Convert the wrapped dataset to a HuggingFace DatasetDict.

        Returns:
            HuggingFace DatasetDict
        """
        # Flatten hierarchical data for HF compatibility
        flattened_data = []

        for sample in self.vcf_dataset.processed_data:
            sample_id = sample['sample_id']
            encoded_data = sample['encoded_data']

            # Convert hierarchical structure to flattened format
            flattened_sample = {
                'sample_id': sample_id,
                'pathways': list(encoded_data.keys()),
                'num_pathways': len(encoded_data),
                'encoded_mutations': self._flatten_mutations(encoded_data)
            }

            flattened_data.append(flattened_sample)

        # Add labels if available
        if self.vcf_dataset.labels is not None:
            for i, sample in enumerate(flattened_data):
                sample['label'] = self.vcf_dataset.labels[i]

        # Create HuggingFace dataset
        hf_dataset = HFDataset.from_list(flattened_data)

        return DatasetDict({'train': hf_dataset})

    def _flatten_mutations(self, encoded_data: Dict) -> Dict[str, List]:
        """Flatten hierarchical mutations for HF compatibility."""
        all_impacts = []
        all_refs = []
        all_alts = []

        for pathway_token, chromosomes in encoded_data.items():
            for chrom_token, genes in chromosomes.items():
                for gene_token, mutations in genes.items():
                    if 'impact' in mutations:
                        all_impacts.extend(mutations['impact'])
                    if 'ref' in mutations:
                        all_refs.extend(mutations['ref'])
                    if 'alt' in mutations:
                        all_alts.extend(mutations['alt'])

        return {
            'impacts': all_impacts,
            'refs': all_refs,
            'alts': all_alts
        }
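
# A minimal usage sketch for the adapter, assuming an existing `dataset`
# (the adapter emits a single 'train' split; pushing to the Hub or further
# processing would follow):
#
#   adapter = HuggingFaceDatasetAdapter(dataset)
#   hf = adapter.to_huggingface_dataset()
#   print(hf['train'][0]['sample_id'])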


def create_dataset_from_config(config_manager: ConfigManager,
                               tokenizer: HierarchicalVCFTokenizer,
                               labels: Optional[List] = None) -> HierarchicalVCFDataset:
    """Create a dataset from the data configuration in a ConfigManager."""
    data_config = config_manager.data_config

    if not data_config.vcf_file_path:
        raise ValueError("VCF file path not specified in configuration")

    return HierarchicalVCFDataset(
        data_source=data_config.vcf_file_path,
        tokenizer=tokenizer,
        config=data_config,
        labels=labels
    )


def create_data_module_from_config(config_manager: ConfigManager,
                                   tokenizer: HierarchicalVCFTokenizer,
                                   labels: Optional[List] = None) -> HierarchicalVCFDataModule:
    """Create a data module from the data configuration in a ConfigManager."""
    data_config = config_manager.data_config

    if not data_config.vcf_file_path:
        raise ValueError("VCF file path not specified in configuration")

    return HierarchicalVCFDataModule(
        data_source=data_config.vcf_file_path,
        tokenizer=tokenizer,
        config=data_config,
        labels=labels
    )


# Utility functions for data preprocessing
def create_synthetic_labels(dataset: HierarchicalVCFDataset,
                            label_type: str = 'random',
                            num_classes: int = 2) -> np.ndarray:
    """
    Create synthetic labels for testing purposes.

    Args:
        dataset: VCF dataset
        label_type: Type of labels ('random', 'mutation_count_based')
        num_classes: Number of classes for classification

    Returns:
        Array of synthetic labels
    """
    num_samples = len(dataset)

    if label_type == 'random':
        return np.random.randint(0, num_classes, size=num_samples)

    elif label_type == 'mutation_count_based':
        # Create labels based on mutation count thresholds
        mutation_counts = dataset.stats['mutations_per_sample']
        threshold = np.median(mutation_counts)

        # Quantile thresholds for the multi-class case, computed once
        percentiles = np.linspace(0, 100, num_classes + 1)
        thresholds = np.percentile(mutation_counts, percentiles[1:-1])

        labels = []
        for count in mutation_counts:
            if num_classes == 2:
                labels.append(1 if count > threshold else 0)
            else:
                # Assign the highest class whose threshold the count exceeds
                label = 0
                for i, t in enumerate(thresholds):
                    if count > t:
                        label = i + 1
                    else:
                        break
                labels.append(label)

        return np.array(labels)

    else:
        raise ValueError(f"Unknown label type: {label_type}")
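
# For example, with num_classes=3 the inner thresholds are the 33.3rd and
# 66.7th percentiles of the mutation counts (np.linspace(0, 100, 4) keeps
# only its interior points), so each count maps to class 0, 1, or 2.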


# Example usage and testing
if __name__ == "__main__":
    from tokenizer import create_tokenizer_from_config

    config_manager = ConfigManager()
    config_manager.data_config.vcf_file_path = "example_data.json"

    # Create tokenizer
    tokenizer = create_tokenizer_from_config(config_manager)

    # Example data
    example_data = {
        'sample1': {
            'pathway1': {
                'chr1': {
                    'gene1': [
                        {'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'},
                        {'impact': 'MODERATE', 'reference': 'G', 'alternate': 'C'}
                    ]
                }
            }
        },
        'sample2': {
            'pathway2': {
                'chr2': {
                    'gene2': [
                        {'impact': 'LOW', 'reference': 'T', 'alternate': 'A'}
                    ]
                }
            }
        }
    }

    # Build tokenizer vocabulary
    tokenizer.build_vocabulary(example_data)

    # Create dataset
    dataset = HierarchicalVCFDataset(
        data_source=example_data,
        tokenizer=tokenizer
    )

    # Create synthetic labels
    labels = create_synthetic_labels(dataset, label_type='random', num_classes=2)
    dataset.labels = labels

    # Create data module
    data_module = HierarchicalVCFDataModule(
        data_source=example_data,
        tokenizer=tokenizer,
        labels=labels,
        train_split=0.6,
        val_split=0.2,
        test_split=0.2
    )

    # Get data loaders
    train_loader, val_loader, test_loader = data_module.get_dataloaders(batch_size=2)

    # Test data loading
    for batch in train_loader:
        print(f"Batch size: {batch['batch_size']}")
        print(f"Sample IDs: {[s.get('sample_id', 'N/A') for s in batch['samples']]}")
        break

    print(f"Dataset statistics: {dataset.get_statistics()}")