abd-ur committed on
Commit
3dec9cb
·
verified ·
1 Parent(s): cc67f2f

Create tokenizer.py


To be developed further.

Files changed (1)
  1. tokenizer.py +512 -0
tokenizer.py ADDED
@@ -0,0 +1,512 @@
+ """
+ Tokenization for VCF data with support for hierarchical structures.
+ """
+
+ import json
+ import logging
+ from pathlib import Path
+ from collections import Counter
+ from typing import Dict, List, Tuple, Optional, Union, Any
+
+ from transformers import PreTrainedTokenizer
+
+ from config import DataConfig, ConfigManager
+ from parser import MutationRecord
+
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class HierarchicalVCFTokenizer(PreTrainedTokenizer):
+     """Tokenizer for hierarchical VCF data (sample -> pathway -> chromosome -> gene -> mutations)."""
+
+     vocab_files_names = {
+         "vocab_file": "vocab.json",
+         "mutation_vocab_file": "mutation_vocab.json"
+     }
+
+     def __init__(self,
+                  vocab_file: Optional[str] = None,
+                  mutation_vocab_file: Optional[str] = None,
+                  config: Optional[DataConfig] = None,
+                  **kwargs):
+
+         self.config = config or DataConfig()
+
+         # Set up special tokens
+         special_tokens = self.config.special_tokens
+         pad_token = special_tokens.get("pad_token", "[PAD]")
+         unk_token = special_tokens.get("unk_token", "[UNK]")
+         sep_token = special_tokens.get("sep_token", "[SEP]")
+         cls_token = special_tokens.get("cls_token", "[CLS]")
+
+         # Initialize the per-field vocabularies *before* calling
+         # super().__init__(): PreTrainedTokenizer registers special tokens
+         # during initialization, which routes through get_vocab/vocab_size
+         # and therefore needs self.field_vocabs to already exist.
+         self.mutation_fields = ['impact', 'ref', 'alt', 'chromosome', 'pathway', 'gene']
+         self.field_vocabs = {}
+         self._initialize_vocabularies(pad_token, unk_token, sep_token, cls_token)
+
+         super().__init__(
+             pad_token=pad_token,
+             unk_token=unk_token,
+             sep_token=sep_token,
+             cls_token=cls_token,
+             **kwargs
+         )
+
+         # Load existing vocabularies if provided
+         if vocab_file and Path(vocab_file).exists():
+             self.load_vocabulary(vocab_file)
+
+         if mutation_vocab_file and Path(mutation_vocab_file).exists():
+             self.load_mutation_vocabulary(mutation_vocab_file)
+
+         # Statistics
+         self.tokenization_stats = {
+             'total_samples': 0,
+             'total_mutations': 0,
+             'vocab_sizes': {}
+         }
+
+     def _initialize_vocabularies(self, pad_token: str, unk_token: str,
+                                  sep_token: str, cls_token: str) -> None:
+         """Seed every field vocabulary with the special tokens."""
+         for field in self.mutation_fields:
+             self.field_vocabs[field] = {
+                 pad_token: 0,
+                 unk_token: 1,
+                 sep_token: 2,
+                 cls_token: 3
+             }
+
+         # Add common genomic tokens
+         self._add_common_genomic_tokens()
+
+     def _add_common_genomic_tokens(self) -> None:
+         """Seed field vocabularies with common genomic values (to be made scalable and dynamic)."""
+         # Common impact values
+         common_impacts = ["HIGH", "MODERATE", "LOW", "MODIFIER"]
+         for impact in common_impacts:
+             if impact not in self.field_vocabs['impact']:
+                 self.field_vocabs['impact'][impact] = len(self.field_vocabs['impact'])
+
+         # Common nucleotides
+         nucleotides = ["A", "T", "G", "C", "N", "-"]
+         for nt in nucleotides:
+             for field in ['ref', 'alt']:
+                 if nt not in self.field_vocabs[field]:
+                     self.field_vocabs[field][nt] = len(self.field_vocabs[field])
+
+         # Common chromosomes
+         chromosomes = [str(i) for i in range(1, 23)] + ["X", "Y", "MT"]
+         for chrom in chromosomes:
+             if chrom not in self.field_vocabs['chromosome']:
+                 self.field_vocabs['chromosome'][chrom] = len(self.field_vocabs['chromosome'])
+
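+     # Resulting seed vocabulary for one field (with the default special
+     # tokens above; ids follow insertion order):
+     #   field_vocabs['impact'] == {'[PAD]': 0, '[UNK]': 1, '[SEP]': 2,
+     #                              '[CLS]': 3, 'HIGH': 4, 'MODERATE': 5,
+     #                              'LOW': 6, 'MODIFIER': 7}
+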
+     def build_vocabulary(self, hierarchical_data: Dict[str, Any]) -> None:
+         """
+         Build the per-field vocabularies from parsed hierarchical VCF data.
+         Args:
+             hierarchical_data: Parsed VCF data structure
+         """
+         logger.info("Building vocabularies from hierarchical data...")
+
+         vocab_counters = {field: Counter() for field in self.mutation_fields}
+
+         for sample_id, pathways in hierarchical_data.items():
+             for pathway_id, chromosomes in pathways.items():
+                 # Count pathway occurrences
+                 vocab_counters['pathway'][pathway_id] += 1
+
+                 for chrom_id, genes in chromosomes.items():
+                     # Count chromosome occurrences
+                     vocab_counters['chromosome'][chrom_id] += 1
+
+                     for gene_id, mutations in genes.items():
+                         # Count gene occurrences
+                         vocab_counters['gene'][gene_id] += 1
+
+                         for mutation in mutations:
+                             if isinstance(mutation, MutationRecord):
+                                 # Count mutation field values
+                                 vocab_counters['impact'][mutation.impact] += 1
+                                 vocab_counters['ref'][mutation.reference] += 1
+                                 vocab_counters['alt'][mutation.alternate] += 1
+                             elif isinstance(mutation, dict):
+                                 # Handle dictionary format
+                                 vocab_counters['impact'][mutation.get('impact', self.unk_token)] += 1
+                                 vocab_counters['ref'][mutation.get('reference', self.unk_token)] += 1
+                                 vocab_counters['alt'][mutation.get('alternate', self.unk_token)] += 1
+
+         # Build vocabularies from the counters, most frequent tokens first
+         for field, counter in vocab_counters.items():
+             for token, _count in counter.most_common():
+                 if token and token not in self.field_vocabs[field]:
+                     self.field_vocabs[field][token] = len(self.field_vocabs[field])
+
+         # Update statistics
+         self.tokenization_stats['vocab_sizes'] = {
+             field: len(vocab) for field, vocab in self.field_vocabs.items()
+         }
+
+         logger.info(f"Vocabulary sizes: {self.tokenization_stats['vocab_sizes']}")
+
+     def encode_hierarchical_sample(self, sample_data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Encode a single hierarchical sample into tokenized form.
+         Args:
+             sample_data: Single sample from hierarchical data
+         Returns:
+             Encoded sample with tokenized values
+         """
+         encoded_sample = {}
+
+         for pathway_id, chromosomes in sample_data.items():
+             # Tokenize pathway ID
+             pathway_token = self.field_vocabs['pathway'].get(
+                 pathway_id, self.field_vocabs['pathway'][self.unk_token]
+             )
+
+             encoded_sample[pathway_token] = {}
+
+             for chrom_id, genes in chromosomes.items():
+                 # Tokenize chromosome ID
+                 chrom_token = self.field_vocabs['chromosome'].get(
+                     chrom_id, self.field_vocabs['chromosome'][self.unk_token]
+                 )
+
+                 encoded_sample[pathway_token][chrom_token] = {}
+
+                 for gene_id, mutations in genes.items():
+                     # Tokenize gene ID
+                     gene_token = self.field_vocabs['gene'].get(
+                         gene_id, self.field_vocabs['gene'][self.unk_token]
+                     )
+
+                     # Encode mutations
+                     encoded_mutations = self._encode_mutations(mutations)
+                     encoded_sample[pathway_token][chrom_token][gene_token] = encoded_mutations
+
+         return encoded_sample
+
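+     # The encoding preserves the nesting but replaces string keys with ids,
+     # e.g. (hypothetical ids) {'pathway1': {'chr1': {'gene1': [...]}}}
+     # becomes {8: {4: {9: {'impact': [...], 'ref': [...], 'alt': [...]}}}}.
+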
+     def _encode_mutations(self, mutations: List[Union[MutationRecord, Dict]]) -> Dict[str, List[int]]:
+         """Tokenize the impact/ref/alt fields of a list of mutations."""
+         encoded_mutations = {
+             'impact': [],
+             'ref': [],
+             'alt': []
+         }
+
+         for mutation in mutations:
+             if isinstance(mutation, MutationRecord):
+                 impact = mutation.impact
+                 ref = mutation.reference
+                 alt = mutation.alternate
+             elif isinstance(mutation, dict):
+                 impact = mutation.get('impact', self.unk_token)
+                 ref = mutation.get('reference', self.unk_token)
+                 alt = mutation.get('alternate', self.unk_token)
+             else:
+                 continue
+
+             # Tokenize each field
+             encoded_mutations['impact'].append(
+                 self.field_vocabs['impact'].get(impact, self.field_vocabs['impact'][self.unk_token])
+             )
+             encoded_mutations['ref'].append(
+                 self.field_vocabs['ref'].get(ref, self.field_vocabs['ref'][self.unk_token])
+             )
+             encoded_mutations['alt'].append(
+                 self.field_vocabs['alt'].get(alt, self.field_vocabs['alt'][self.unk_token])
+             )
+
+         # Keep the running mutation count in the statistics
+         self.tokenization_stats['total_mutations'] += len(encoded_mutations['impact'])
+
+         return encoded_mutations
+
+     def encode_batch(self, batch_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+         """
+         Encode a batch of hierarchical samples.
+         Args:
+             batch_data: List of sample dictionaries
+         Returns:
+             List of encoded samples
+         """
+         encoded_batch = []
+
+         for sample_data in batch_data:
+             encoded_sample = self.encode_hierarchical_sample(sample_data)
+             encoded_batch.append(encoded_sample)
+
+         self.tokenization_stats['total_samples'] += len(batch_data)
+
+         return encoded_batch
+
+     def decode_tokens(self, field: str, token_ids: List[int]) -> List[str]:
+         """
+         Decode token IDs back to their original string values.
+         Args:
+             field: Field name ('impact', 'ref', 'alt', etc.)
+             token_ids: List of token IDs
+         Returns:
+             List of decoded tokens
+         """
+         if field not in self.field_vocabs:
+             raise ValueError(f"Unknown field: {field}")
+
+         id_to_token = {v: k for k, v in self.field_vocabs[field].items()}
+         return [id_to_token.get(token_id, self.unk_token) for token_id in token_ids]
+
+     def get_vocab_size(self, field: str) -> int:
+         """Get the vocabulary size for a specific field."""
+         if field not in self.field_vocabs:
+             raise ValueError(f"Unknown field: {field}")
+         return len(self.field_vocabs[field])
+
+     def get_all_vocab_sizes(self) -> Dict[str, int]:
+         """Get vocabulary sizes for all fields."""
+         return {field: len(vocab) for field, vocab in self.field_vocabs.items()}
+
+     def save_vocabulary(self, save_directory: Union[str, Path], filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
+         """
+         Save the field vocabularies and tokenizer configuration to disk.
+         Args:
+             save_directory: Directory to save vocabularies
+             filename_prefix: Optional prefix for filenames
+         Returns:
+             Tuple of saved file paths
+         """
+         save_directory = Path(save_directory)
+         save_directory.mkdir(parents=True, exist_ok=True)
+
+         prefix = f"{filename_prefix}_" if filename_prefix else ""
+
+         # Save mutation vocabularies
+         mutation_vocab_file = save_directory / f"{prefix}mutation_vocab.json"
+         with open(mutation_vocab_file, 'w') as f:
+             json.dump(self.field_vocabs, f, indent=2)
+
+         # Save tokenizer configuration
+         config_file = save_directory / f"{prefix}tokenizer_config.json"
+         config_data = {
+             'tokenizer_class': self.__class__.__name__,
+             'special_tokens': {
+                 'pad_token': self.pad_token,
+                 'unk_token': self.unk_token,
+                 'sep_token': self.sep_token,
+                 'cls_token': self.cls_token
+             },
+             'vocab_sizes': self.get_all_vocab_sizes(),
+             'mutation_fields': self.mutation_fields
+         }
+
+         with open(config_file, 'w') as f:
+             json.dump(config_data, f, indent=2)
+
+         logger.info(f"Vocabularies saved to {save_directory}")
+
+         return str(mutation_vocab_file), str(config_file)
+
+     def load_vocabulary(self, vocab_file: Union[str, Path]) -> None:
+         """Load field vocabularies from a JSON file."""
+         vocab_file = Path(vocab_file)
+
+         if not vocab_file.exists():
+             raise FileNotFoundError(f"Vocabulary file not found: {vocab_file}")
+
+         with open(vocab_file, 'r') as f:
+             vocab_data = json.load(f)
+
+         # Update vocabularies
+         for field, vocab in vocab_data.items():
+             if field in self.mutation_fields:
+                 self.field_vocabs[field] = vocab
+
+         logger.info(f"Vocabularies loaded from {vocab_file}")
+
+     def load_mutation_vocabulary(self, mutation_vocab_file: Union[str, Path]) -> None:
+         """Load mutation-specific vocabularies from file."""
+         self.load_vocabulary(mutation_vocab_file)
+
+     def create_padding_masks(self, encoded_sample: Dict[str, Any], max_lengths: Dict[str, int]) -> Dict[str, Any]:
+         """
+         Create padding masks for hierarchical data.
+         Args:
+             encoded_sample: Encoded sample data
+             max_lengths: Maximum lengths for each level
+         Returns:
+             Sample with padding masks
+         """
+         masked_sample = {}
+
+         for pathway_token, chromosomes in encoded_sample.items():
+             masked_sample[pathway_token] = {}
+
+             for chrom_token, genes in chromosomes.items():
+                 masked_sample[pathway_token][chrom_token] = {}
+
+                 for gene_token, mutations in genes.items():
+                     masked_mutations = {}
+
+                     for field, token_list in mutations.items():
+                         max_len = max_lengths.get(f'mutations_{field}', 100)
+
+                         # Pad or truncate to max_len
+                         if len(token_list) < max_len:
+                             padded_list = token_list + [self.field_vocabs[field][self.pad_token]] * (max_len - len(token_list))
+                             mask = [1] * len(token_list) + [0] * (max_len - len(token_list))
+                         else:
+                             padded_list = token_list[:max_len]
+                             mask = [1] * max_len
+
+                         masked_mutations[field] = {
+                             'tokens': padded_list,
+                             'mask': mask
+                         }
+
+                     masked_sample[pathway_token][chrom_token][gene_token] = masked_mutations
+
+         return masked_sample
+
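+     # Example: tokens [4, 5] with max_len 4 become
+     #     {'tokens': [4, 5, 0, 0], 'mask': [1, 1, 0, 0]}
+     # since the pad token holds id 0 in every field vocabulary.
+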
+     def get_tokenization_statistics(self) -> Dict[str, Any]:
+         """Return a copy of the tokenization statistics with current vocab sizes."""
+         stats = self.tokenization_stats.copy()
+         stats['vocab_sizes'] = self.get_all_vocab_sizes()
+         return stats
+
+     # Hugging Face compatibility methods
+     @property
+     def vocab_size(self) -> int:
+         return sum(len(vocab) for vocab in self.field_vocabs.values())
+
+     def get_vocab(self) -> Dict[str, int]:
+         combined_vocab = {}
+         offset = 0
+
+         # Namespace tokens as "field:token" and shift each field's ids by the
+         # running offset so the combined ids do not collide across fields
+         for field, vocab in self.field_vocabs.items():
+             for token, idx in vocab.items():
+                 combined_vocab[f"{field}:{token}"] = idx + offset
+             offset += len(vocab)
+
+         return combined_vocab
+
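+     # E.g. with the seeded vocabularies above, get_vocab() contains entries
+     # such as "impact:[PAD]" -> 0 and "impact:HIGH" -> 4, while the first
+     # 'ref' entry starts at offset len(field_vocabs['impact']).
+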
+     def _tokenize(self, text: str) -> List[str]:
+         # Simplified implementation for Hugging Face compatibility only;
+         # hierarchical data should go through encode_hierarchical_sample instead
+         return text.split()
+
+     def _convert_token_to_id(self, token: str) -> int:
+         # Parse the "field:token" format used by get_vocab
+         if ':' in token:
+             field, actual_token = token.split(':', 1)
+             if field in self.field_vocabs:
+                 return self.field_vocabs[field].get(actual_token, self.field_vocabs[field][self.unk_token])
+
+         # Fall back to the unknown-token id
+         return self.field_vocabs.get('impact', {}).get(self.unk_token, 1)
+
+     def _convert_id_to_token(self, index: int) -> str:
+         # Simplified reverse lookup: ids are not unique across fields, so this
+         # returns the first field that contains the index
+         for field, vocab in self.field_vocabs.items():
+             id_to_token = {v: k for k, v in vocab.items()}
+             if index in id_to_token:
+                 return f"{field}:{id_to_token[index]}"
+
+         return self.unk_token
+
+
+ class HierarchicalDataCollator:
+     """Collates encoded hierarchical samples into padded, masked batches."""
+
+     def __init__(self, tokenizer: HierarchicalVCFTokenizer, max_lengths: Optional[Dict[str, int]] = None):
+         self.tokenizer = tokenizer
+         self.max_lengths = max_lengths or {
+             'mutations_impact': 50,
+             'mutations_ref': 50,
+             'mutations_alt': 50,
+             'genes_per_chromosome': 100,
+             'chromosomes_per_pathway': 25,
+             'pathways_per_sample': 50
+         }
+
+     def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+         """
+         Collate a batch of hierarchical samples.
+         Args:
+             batch: List of encoded hierarchical samples
+         Returns:
+             Collated batch ready for model input
+         """
+         collated_batch = {
+             'samples': [],
+             'batch_size': len(batch),
+             'metadata': {
+                 'num_pathways': [],
+                 'num_chromosomes': [],
+                 'num_genes': [],
+                 'num_mutations': []
+             }
+         }
+
+         for sample in batch:
+             # Create padding masks
+             masked_sample = self.tokenizer.create_padding_masks(sample, self.max_lengths)
+             collated_batch['samples'].append(masked_sample)
+
+             # Collect per-sample size metadata
+             num_pathways = len(sample)
+             num_chromosomes = sum(len(chroms) for chroms in sample.values())
+             num_genes = sum(
+                 len(genes) for chroms in sample.values()
+                 for genes in chroms.values()
+             )
+             num_mutations = sum(
+                 len(mutations.get('impact', []))
+                 for chroms in sample.values()
+                 for genes in chroms.values()
+                 for mutations in genes.values()
+             )
+
+             collated_batch['metadata']['num_pathways'].append(num_pathways)
+             collated_batch['metadata']['num_chromosomes'].append(num_chromosomes)
+             collated_batch['metadata']['num_genes'].append(num_genes)
+             collated_batch['metadata']['num_mutations'].append(num_mutations)
+
+         return collated_batch
+
+
+ def create_tokenizer_from_config(config_manager: ConfigManager) -> HierarchicalVCFTokenizer:
+     """Create a tokenizer from a configuration manager."""
+     return HierarchicalVCFTokenizer(config=config_manager.data_config)
+
+
+ # Example usage and testing
+ if __name__ == "__main__":
+     config_manager = ConfigManager()
+     tokenizer = create_tokenizer_from_config(config_manager)
+
+     # Example hierarchical data structure
+     example_data = {
+         'sample1': {
+             'pathway1': {
+                 'chr1': {
+                     'gene1': [
+                         {
+                             'impact': 'HIGH',
+                             'reference': 'A',
+                             'alternate': 'T'
+                         }
+                     ]
+                 }
+             }
+         }
+     }
+
+     # Build vocabulary
+     tokenizer.build_vocabulary(example_data)
+
+     # Encode sample
+     encoded = tokenizer.encode_hierarchical_sample(example_data['sample1'])
+     print(f"Encoded sample: {encoded}")
+
+     # Save vocabulary
+     tokenizer.save_vocabulary("./tokenizer_files")
+
+     print(f"Tokenization statistics: {tokenizer.get_tokenization_statistics()}")
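+
+     # Collate an encoded batch (sketch using the collator defined above)
+     collator = HierarchicalDataCollator(tokenizer)
+     batch = collator(tokenizer.encode_batch([example_data['sample1']]))
+     print(f"Collated batch metadata: {batch['metadata']}")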