File size: 23,389 Bytes
1d6f391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
"""
Standalone WURCS Tokenizer for Fine-tuning

This tokenizer matches the v3.1 pre-training tokenization exactly:
- Uses bracket-counting parser (handles '/' in modifications)
- 167 tokens (including 58 whole-token modifications)
- Same special tokens: [PAD], [UNK], [START], [END], [BRANCH_OPEN], [BRANCH_CLOSE], etc.
- Detects and inserts branch tokens based on WURCS topology

Usage:
    tokenizer = WURCSTokenizer('data/vocabulary.json')
    result = tokenizer.tokenize(wurcs_string)
    # result = {'tokens': [...], 'token_ids': [...], 'length': N, ...}
    # (also includes 'attention_mask', 'residue_ids', 'is_branched', and more)
"""

import json
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict


def parse_wurcs_sections(wurcs: str) -> List[str]:
    """
    Split a WURCS string into its '/'-separated sections.

    A '/' only counts as a separator while we are outside of any
    square-bracket group, so modifications such as NCC/3=O that live
    inside a residue definition are kept intact.

    Args:
        wurcs: WURCS string

    Returns:
        List of sections, normally [version, counts, residues, topology, linkages].
        Empty sections between two separators are kept; an empty trailing
        section (string ending in '/') is dropped.
    """
    pieces = []
    depth = 0          # current '[' ... ']' nesting level
    start = 0          # index where the section being scanned begins

    for pos, ch in enumerate(wurcs):
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
        elif ch == '/' and depth == 0:
            # Top-level separator: close the section that ends just before it
            pieces.append(wurcs[start:pos])
            start = pos + 1

    # Whatever follows the last separator is the final section (if non-empty)
    tail = wurcs[start:]
    if tail:
        pieces.append(tail)

    return pieces


def extract_residues_from_wurcs(wurcs: str) -> List[str]:
    """
    Return the unique residue definitions listed in a WURCS string.

    The string is first split with parse_wurcs_sections(), so '/' characters
    inside modifications (e.g., 2*NCC/3=O, 4*OSO/3=O/3=O) do not break the
    residue section apart.

    Args:
        wurcs: WURCS string

    Returns:
        List of residue strings without their surrounding brackets
        (e.g., ['a2122h-1a_1-5_2*NCC/3=O', ...]); empty list when the
        string has fewer than three sections.
    """
    parsed = parse_wurcs_sections(wurcs)

    # The third section holds the residue list: [type1][type2]...
    if len(parsed) >= 3:
        return re.findall(r'\[([^\]]+)\]', parsed[2])

    return []


def extract_linkages_from_wurcs(wurcs: str) -> List[str]:
    """
    Return the linkage entries of a WURCS string.

    The string is first split with parse_wurcs_sections(), so '/' characters
    inside modifications are not mistaken for section separators.

    Args:
        wurcs: WURCS string

    Returns:
        List of non-empty linkage strings (e.g., ['a4-b1', 'b3-c2']);
        empty list when the string has fewer than five sections.
    """
    parsed = parse_wurcs_sections(wurcs)

    if len(parsed) >= 5:
        # Fifth section looks like a4-b1_b3-c2_...; drop empty fragments
        return list(filter(None, parsed[4].split('_')))

    return []


def detect_branch_points(linkages: List[str]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Find residues that fan out to more than one child.

    Each linkage such as "a3-b1" or "a?-b1" is read as an edge from the
    left residue letter to the right one. Linkages that do not start with
    that shape are ignored.

    Args:
        linkages: List of linkage strings (e.g., ['a4-b1', 'a3-c2'])

    Returns:
        Tuple of:
            - branch_points: residue letter -> children, only for residues with >1 children
            - outgoing: residue letter -> children, for every residue that has any child
    """
    link_pattern = re.compile(r'([a-z])([?\d]+)-([a-z])([?\d]+)')
    children_of = {}

    for entry in linkages:
        hit = link_pattern.match(entry)
        if hit is None:
            continue
        # Only the two residue letters matter here; positions are ignored
        parent, child = hit.group(1), hit.group(3)
        children_of.setdefault(parent, []).append(child)

    branching = {
        parent: kids
        for parent, kids in children_of.items()
        if len(kids) >= 2
    }
    return branching, children_of


def compute_residue_depths(linkages: List[str]) -> Dict[str, int]:
    """
    Assign a tree depth to every residue reachable from the root.

    Linkages are read as parent -> child edges. The root is the
    alphabetically smallest residue that never appears as a child
    (falling back to 'a' when no such residue exists). Depths are then
    assigned breadth-first: 0 for the root, 1 for its children, and so on.

    Args:
        linkages: List of linkage strings (e.g., ['a4-b1', 'b3-c1'])

    Returns:
        Dict mapping residue letter to its depth; residues that cannot be
        reached from the chosen root are absent.
    """
    link_pattern = re.compile(r'([a-z])([?\d]+)-([a-z])([?\d]+)')

    # Collect edges plus the sets of residues seen on each side of a link
    adjacency = defaultdict(list)
    seen_as_parent = set()
    seen_as_child = set()
    for entry in linkages:
        hit = link_pattern.match(entry)
        if not hit:
            continue
        src, dst = hit.group(1), hit.group(3)
        adjacency[src].append(dst)
        seen_as_parent.add(src)
        seen_as_child.add(dst)

    # A root is a residue with no incoming link - usually 'a'
    root_candidates = seen_as_parent - seen_as_child
    root = min(root_candidates) if root_candidates else 'a'

    # Breadth-first walk, one tree level at a time
    depths = {root: 0}
    frontier = [root]
    level = 0
    while frontier:
        level += 1
        upcoming = []
        for node in frontier:
            for kid in adjacency.get(node, ()):
                if kid not in depths:
                    depths[kid] = level
                    upcoming.append(kid)
        frontier = upcoming

    return depths


def parse_linkage_type(link: str) -> int:
    """
    Map a linkage string to a linkage type ID based on its two positions.

    Only the position numbers on either side of the '-' are used (e.g.
    4 and 1 in 'a4-b1'); their order does not matter.

    Args:
        link: Linkage string like 'a4-b1' or 'a3-b1'

    Returns:
        Linkage type ID (0-8):
        0 = 1-2, 1 = 1-3, 2 = 1-4, 3 = 1-6, 4 = 2-3, 5 = 2-6, 6 = 3-6,
        7 = any other numeric position pair,
        8 = unknown (string not recognised, or a position contains '?')
    """
    # Keys are (smaller position, larger position)
    pair_to_id = {
        (1, 2): 0,
        (1, 3): 1,
        (1, 4): 2,
        (1, 6): 3,
        (2, 3): 4,
        (2, 6): 5,
        (3, 6): 6,
    }

    hit = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link)
    if hit is None:
        return 8

    try:
        low, high = sorted((int(hit.group(2)), int(hit.group(4))))
    except ValueError:
        # A '?' placeholder makes the position non-numeric
        return 8

    return pair_to_id.get((low, high), 7)


class WURCSTokenizer:
    """
    WURCS tokenizer matching v3.1 pre-training exactly.

    Key features:
    - Bracket-counting parser for '/' handling in modifications
    - Branch token insertion based on topology
    - Whole-token modifications (58 total)
    """

    # Vocabulary groups merged into token_to_id, in this order
    # (a token present in several groups keeps the ID of the last one).
    _VOCAB_GROUPS = (
        'special_tokens',
        'skeleton_atoms',
        'linkage_atoms',
        'anomeric_symbols',
        'anomeric_positions',
        'whole_modifications',
    )

    # Skeleton prefix -> monosaccharide label. Checked in insertion order,
    # first matching prefix wins. NOTE: the 'Aad21122h-2' entry can never
    # match because the skeleton is cut at the first '-' before lookup; it
    # is kept only so the table stays identical to the original.
    _MONO_PATTERNS = {
        'a2122h': 'Glc', 'a2112h': 'Gal', 'a1221m': 'Fuc', 'a2211m': 'Rha',
        'a212h': 'Xyl', 'a21d2h': 'Man', 'a2112m': 'Ara', 'a2d21h': 'Ido',
        'axxxxh': 'Hex', 'Aad21122h': 'Neu5Ac', 'Aad21122h-2': 'Neu5Gc',
    }

    # Residue letters used to look residues up in the linkage graph;
    # residues past the 26th all fall back to 'z'.
    _RESIDUE_LETTERS = 'abcdefghijklmnopqrstuvwxyz'

    def __init__(self, vocab_path: str):
        """
        Initialize tokenizer with vocabulary.

        Args:
            vocab_path: Path to vocabulary.json (same as used in v3.1 pre-training)

        Raises:
            OSError: If the vocabulary file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON is UTF-8 by specification; do not rely on the platform default
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)

        # Build token_to_id mapping from every vocabulary group
        self.token_to_id = {}
        for group in self._VOCAB_GROUPS:
            self.token_to_id.update(self.vocab.get(group, {}))

        # Build reverse mapping
        self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}

        # Set special token IDs (fallbacks are used when the vocabulary
        # file does not define the token)
        self.pad_token_id = self.token_to_id.get('[PAD]', 0)
        self.unk_token_id = self.token_to_id.get('[UNK]', 1)
        self.start_token_id = self.token_to_id.get('[START]', 2)
        self.end_token_id = self.token_to_id.get('[END]', 3)
        self.branch_open_id = self.token_to_id.get('[BRANCH_OPEN]', 5)
        self.branch_close_id = self.token_to_id.get('[BRANCH_CLOSE]', 6)
        self.mod_token_id = self.token_to_id.get('[MOD]', 9)
        self.link_token_id = self.token_to_id.get('[LINK]', 7)

        self.vocab_size = self.vocab.get('metadata', {}).get('total_vocab_size', len(self.token_to_id))

    def _monosaccharide_name(self, parts: List[str]) -> str:
        """Derive a display name for a residue from its '_'-separated parts."""
        # Skeleton is everything before the first '-' of the first part
        skeleton = parts[0].split('-')[0]

        name = '<UNK>'
        for pattern, candidate in self._MONO_PATTERNS.items():
            if skeleton.startswith(pattern):
                name = candidate
                break

        # Adjust for modifications (GlcNAc, GalNAc, sulfation)
        has_nac = any('NCC/3=O' in p or 'NAc' in p for p in parts)
        has_s = any('OSO' in p for p in parts)
        if has_nac and name in ('Glc', 'Gal'):
            return name + 'NAc'
        if has_s:
            return name + 'S'
        return name

    def tokenize(self, wurcs: str, max_length: int = 512) -> Dict:
        """
        Tokenize a WURCS string.

        Matches v3.1 pre-training tokenization exactly:
        - Bracket-counting parser for '/' handling
        - Branch token insertion based on topology
        - Whole-token modifications

        The sequence layout is: [START], the residues in topology order
        (each optionally preceded by [BRANCH_OPEN]), one [BRANCH_CLOSE] per
        [BRANCH_OPEN], the linkages (each introduced by [LINK]), [END],
        then [PAD] up to max_length.

        Tokenization is best-effort: if parsing raises, the tokens produced
        so far are kept, [END] is appended, and the error is logged at
        DEBUG level instead of being propagated.

        Args:
            wurcs: WURCS string to tokenize
            max_length: Maximum sequence length (truncate if longer, pad if shorter)

        Returns:
            Dict with:
                - tokens: List of token strings (padded)
                - token_ids: List of token IDs (padded)
                - residue_ids: Residue index for each token
                  (-1 = special/padding, -2 = linkage)
                - branch_depths: Tree depth of the owning residue for each
                  token (0 for special, linkage and padding tokens,
                  except [BRANCH_OPEN] which carries its residue's depth)
                - linkage_types: Linkage type ID (0-8) for linkage tokens,
                  0 for everything else
                - monosaccharide_names: One name per tokenized residue
                - num_residues: len(monosaccharide_names)
                - length: Number of real (non-padding) tokens
                - attention_mask: 1 for real tokens, 0 for padding
                - is_branched: Whether the glycan has branches
                - has_unk_mod: Whether an unknown modification was seen
                  among the real tokens

            monosaccharide_names, num_residues and is_branched describe the
            whole input and are not affected by truncation.
        """
        all_tokens = []
        all_token_ids = []
        all_residue_ids = []    # Which residue each token belongs to
        all_branch_depths = []  # Branch depth for each token (0=root, 1=child, etc.)
        all_linkage_types = []  # Linkage type for each token (0-8)
        monosaccharide_names = []  # Names of monosaccharides in order
        is_branched = False

        def emit(token, token_id, residue_id, depth=0, link_type=0):
            # Single place that grows the five parallel lists, so they can
            # never get out of step with each other.
            all_tokens.append(token)
            all_token_ids.append(token_id)
            all_residue_ids.append(residue_id)
            all_branch_depths.append(depth)
            all_linkage_types.append(link_type)

        # Start token (residue_id = -1 for special tokens, depth 0, no linkage)
        emit('[START]', self.start_token_id, -1)

        try:
            # Extract residues and linkages
            residues = extract_residues_from_wurcs(wurcs)
            linkages = extract_linkages_from_wurcs(wurcs)

            # Detect branching and compute residue depths
            branch_points, _ = detect_branch_points(linkages)
            is_branched = len(branch_points) > 0
            residue_depths = compute_residue_depths(linkages)

            # Get residue order from topology (1-based indices into residues)
            sections = parse_wurcs_sections(wurcs)
            topology = sections[3] if len(sections) > 3 else ""
            residue_order = [int(x) for x in topology.split('-') if x.isdigit()]

            # Tokenize each residue in order
            current_residue_id = 0
            for i, res_idx in enumerate(residue_order):
                if res_idx < 1 or res_idx > len(residues):
                    continue

                residue = residues[res_idx - 1]
                if i < len(self._RESIDUE_LETTERS):
                    res_letter = self._RESIDUE_LETTERS[i]
                else:
                    res_letter = 'z'

                # Residue layout: skeleton-anomer _ ring closure _ mods...
                parts = residue.split('_')
                monosaccharide_names.append(self._monosaccharide_name(parts))

                res_depth = residue_depths.get(res_letter, 0)

                # Branch points are announced before their own tokens
                if res_letter in branch_points:
                    emit('[BRANCH_OPEN]', self.branch_open_id, -1, res_depth)

                # Part 0: skeleton and anomer (e.g., "a2122h-1b"), one token
                # per character. A '-' that is not in the vocabulary is
                # dropped; any other unknown character becomes [UNK].
                for char in parts[0]:
                    if char in self.token_to_id:
                        emit(char, self.token_to_id[char], current_residue_id, res_depth)
                    elif char != '-':
                        emit('[UNK]', self.unk_token_id, current_residue_id, res_depth)

                # Parts 2+: modifications (part 1 is the ring closure and is
                # skipped); only parts containing '*' are tokenized.
                for part in parts[2:]:
                    if '*' not in part:
                        continue
                    emit('[MOD]', self.mod_token_id, current_residue_id, res_depth)
                    if part in self.token_to_id:
                        emit(part, self.token_to_id[part], current_residue_id, res_depth)
                    else:
                        emit(
                            '[UNK_MOD]',
                            self.token_to_id.get('[UNK_MOD]', self.unk_token_id),
                            current_residue_id,
                            res_depth,
                        )

                current_residue_id += 1

            # Close every opened branch at the end of the residue block
            for _ in range(all_tokens.count('[BRANCH_OPEN]')):
                emit('[BRANCH_CLOSE]', self.branch_close_id, -1)

            # Tokenize linkages (residue_id = -2, depth 0, real linkage type)
            for linkage in linkages:
                if not linkage:
                    continue
                link_type = parse_linkage_type(linkage)

                emit('[LINK]', self.link_token_id, -2, 0, link_type)
                for char in linkage:
                    if char in self.token_to_id:
                        emit(char, self.token_to_id[char], -2, 0, link_type)
                    else:
                        emit('[UNK]', self.unk_token_id, -2, 0, link_type)

        except Exception:
            # Deliberately best-effort: keep whatever was tokenized so far,
            # but leave a trace so malformed inputs can be diagnosed.
            import logging
            logging.getLogger(__name__).debug(
                "WURCS tokenization stopped early for %r", wurcs, exc_info=True
            )

        # End token
        emit('[END]', self.end_token_id, -1)

        # Truncate if necessary (the last kept slot is overwritten by [END])
        if len(all_token_ids) > max_length:
            all_tokens = all_tokens[:max_length-1] + ['[END]']
            all_token_ids = all_token_ids[:max_length-1] + [self.end_token_id]
            all_residue_ids = all_residue_ids[:max_length-1] + [-1]
            all_branch_depths = all_branch_depths[:max_length-1] + [0]
            all_linkage_types = all_linkage_types[:max_length-1] + [0]

        length = len(all_token_ids)
        attention_mask = [1] * length

        # Pad to max_length
        padding_length = max_length - length
        if padding_length > 0:
            all_tokens = all_tokens + ['[PAD]'] * padding_length
            all_token_ids = all_token_ids + [self.pad_token_id] * padding_length
            all_residue_ids = all_residue_ids + [-1] * padding_length
            all_branch_depths = all_branch_depths + [0] * padding_length
            all_linkage_types = all_linkage_types + [0] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        return {
            'tokens': all_tokens,
            'token_ids': all_token_ids,
            'residue_ids': all_residue_ids,
            'branch_depths': all_branch_depths,
            'linkage_types': all_linkage_types,
            'monosaccharide_names': monosaccharide_names,
            'num_residues': len(monosaccharide_names),
            'length': length,
            'attention_mask': attention_mask,
            'is_branched': is_branched,
            'has_unk_mod': '[UNK_MOD]' in all_tokens[:length]
        }

    def batch_tokenize(self, wurcs_list: List[str], max_length: int = 512) -> Dict:
        """
        Tokenize a batch of WURCS strings.

        Args:
            wurcs_list: List of WURCS strings
            max_length: Maximum sequence length (applied to every string)

        Returns:
            Dict of plain Python lists, one entry per input string:
                - token_ids: padded token ID lists
                - attention_mask: matching attention masks
                - lengths: number of real tokens in each sequence
        """
        results = [self.tokenize(w, max_length) for w in wurcs_list]

        return {
            'token_ids': [r['token_ids'] for r in results],
            'attention_mask': [r['attention_mask'] for r in results],
            'lengths': [r['length'] for r in results]
        }

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to tokens (for debugging).

        IDs missing from the vocabulary are shown as '[UNK]'; '[PAD]'
        tokens are removed from the output.

        Args:
            token_ids: List of token IDs

        Returns:
            Space-separated token string
        """
        tokens = [self.id_to_token.get(tid, '[UNK]') for tid in token_ids]
        # Remove padding
        tokens = [t for t in tokens if t != '[PAD]']
        return ' '.join(tokens)


def create_tokenizer(vocab_path: Optional[str] = None) -> WURCSTokenizer:
    """
    Create a tokenizer, locating vocabulary.json automatically if needed.

    When vocab_path is omitted, data/vocabulary.json is looked up relative
    to this file: first three directory levels up, then two levels up. The
    first existing file is used.

    Args:
        vocab_path: Optional path to vocabulary.json

    Returns:
        WURCSTokenizer instance

    Raises:
        FileNotFoundError: If vocab_path is None and no default location exists.
    """
    if vocab_path is None:
        here = Path(__file__).parent
        # Try default locations, in priority order
        candidates = [
            here.parent.parent / 'data' / 'vocabulary.json',
            here.parent / 'data' / 'vocabulary.json',
        ]
        found = next((p for p in candidates if p.exists()), None)
        if found is None:
            searched = ', '.join(str(p) for p in candidates)
            raise FileNotFoundError(
                f"vocabulary.json not found. Please specify path. Searched: {searched}"
            )
        vocab_path = str(found)

    return WURCSTokenizer(vocab_path)


if __name__ == '__main__':
    # Manual smoke test: load the repository vocabulary and print the
    # tokenization of a couple of sample glycans.
    import sys

    # The vocabulary is expected three directory levels above this file
    vocab_path = Path(__file__).parent.parent.parent / 'data' / 'vocabulary.json'
    if not vocab_path.exists():
        print(f"Vocabulary not found at {vocab_path}")
        sys.exit(1)

    tokenizer = WURCSTokenizer(str(vocab_path))
    print(f"Loaded tokenizer with {tokenizer.vocab_size} tokens")

    # One linear disaccharide and one trisaccharide carrying a '/'-containing
    # modification
    test_cases = (
        "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1",
        "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
    )

    for wurcs in test_cases:
        result = tokenizer.tokenize(wurcs, max_length=64)
        report = [
            f"\nWURCS: {wurcs[:50]}...",
            f"  Length: {result['length']}",
            f"  Branched: {result['is_branched']}",
            f"  Tokens: {result['tokens'][:15]}...",
            f"  Token IDs: {result['token_ids'][:15]}...",
        ]
        print("\n".join(report))