""" Standalone WURCS Tokenizer for Fine-tuning This tokenizer matches the v3.1 pre-training tokenization exactly: - Uses bracket-counting parser (handles '/' in modifications) - 167 tokens (including 58 whole-token modifications) - Same special tokens: [PAD], [UNK], [START], [END], [BRANCH_OPEN], [BRANCH_CLOSE], etc. - Detects and inserts branch tokens based on WURCS topology Usage: tokenizer = WURCSTokenizer('data/vocabulary.json') result = tokenizer.tokenize(wurcs_string) # result = {'tokens': [...], 'token_ids': [...], 'length': N} """ import json import re from pathlib import Path from typing import Dict, List, Optional, Tuple from collections import defaultdict def parse_wurcs_sections(wurcs: str) -> List[str]: """ Parse WURCS string into sections by counting brackets. This correctly handles modifications with '/' characters like NCC/3=O. The simple regex approach fails because '/' appears inside modifications. Args: wurcs: WURCS string Returns: List of sections [version, counts, residues, topology, linkages] """ sections = [] current_section = [] bracket_depth = 0 for char in wurcs: if char == '[': bracket_depth += 1 current_section.append(char) elif char == ']': bracket_depth -= 1 current_section.append(char) elif char == '/' and bracket_depth == 0: # This '/' is a section separator (not inside a modification) sections.append(''.join(current_section)) current_section = [] else: current_section.append(char) # Add last section if current_section: sections.append(''.join(current_section)) return sections def extract_residues_from_wurcs(wurcs: str) -> List[str]: """ Extract residue definitions from WURCS string. Uses bracket-counting parser to correctly handle modifications containing '/' characters (e.g., 2*NCC/3=O, 4*OSO/3=O/3=O). Args: wurcs: WURCS string Returns: List of residue strings (e.g., ['a2122h-1a_1-5_2*NCC/3=O', ...]) """ sections = parse_wurcs_sections(wurcs) if len(sections) < 3: return [] # Section 2 contains residue definitions: [type1][type2]... residues_section = sections[2] # Find all residues [...] residues = re.findall(r'\[([^\]]+)\]', residues_section) return residues def extract_linkages_from_wurcs(wurcs: str) -> List[str]: """ Extract linkage patterns from WURCS string. Uses bracket-counting parser to correctly handle modifications containing '/' characters. Args: wurcs: WURCS string Returns: List of linkage strings (e.g., ['a4-b1', 'b3-c2']) """ sections = parse_wurcs_sections(wurcs) if len(sections) < 5: return [] # Section 4 contains linkages: a4-b1_b3-c2_... linkage_section = sections[4] # Split by underscore linkages = linkage_section.split('_') return [l for l in linkages if l] def detect_branch_points(linkages: List[str]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]: """ Detect branch points from linkages. A residue is a branch point if it has >1 outgoing connections. Args: linkages: List of linkage strings (e.g., ['a4-b1', 'a3-c2']) Returns: Tuple of: - branch_points: Dict mapping residue letter to list of children (only for >1 children) - outgoing: Dict mapping all residue letters to their children """ outgoing = defaultdict(list) for link in linkages: # Parse link like "a3-b1" or "a?-b1" match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link) if match: from_res, from_pos, to_res, to_pos = match.groups() outgoing[from_res].append(to_res) # Branch points have >1 children branch_points = {k: v for k, v in outgoing.items() if len(v) > 1} return branch_points, dict(outgoing) def compute_residue_depths(linkages: List[str]) -> Dict[str, int]: """ Compute depth of each residue in the glycan tree using BFS from root. Root is typically 'a' (first residue). Depth 0 = root, 1 = children of root, etc. Args: linkages: List of linkage strings (e.g., ['a4-b1', 'b3-c1']) Returns: Dict mapping residue letter to its depth in the tree """ # Build parent-child graph from linkages children = defaultdict(list) all_residues = set() for link in linkages: match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link) if match: from_res, _, to_res, _ = match.groups() children[from_res].append(to_res) all_residues.add(from_res) all_residues.add(to_res) # Find root (residue with no parent) - usually 'a' child_set = set() for kids in children.values(): child_set.update(kids) roots = all_residues - child_set root = min(roots) if roots else 'a' # Default to 'a' if can't determine # BFS from root to compute depths depths = {root: 0} queue = [root] while queue: current = queue.pop(0) for child in children.get(current, []): if child not in depths: depths[child] = depths[current] + 1 queue.append(child) return depths def parse_linkage_type(link: str) -> int: """ Parse a linkage string to extract linkage type ID. Linkage types based on anomeric configuration and position: - α1-2, α1-3, α1-4, α1-6, β1-2, β1-3, β1-4, β1-6, unknown WURCS uses position numbers (1-6) in linkage patterns like 'a4-b1'. We extract the positions to determine linkage type. Args: link: Linkage string like 'a4-b1' or 'a3-b1' Returns: Linkage type ID (0-8): 0 = 1-2, 1 = 1-3, 2 = 1-4, 3 = 1-6, 4 = 2-3, 5 = 2-6, 6 = 3-6, 7 = other, 8 = unknown """ LINKAGE_TYPES = { (1, 2): 0, (2, 1): 0, (1, 3): 1, (3, 1): 1, (1, 4): 2, (4, 1): 2, (1, 6): 3, (6, 1): 3, (2, 3): 4, (3, 2): 4, (2, 6): 5, (6, 2): 5, (3, 6): 6, (6, 3): 6, } match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link) if match: _, from_pos, _, to_pos = match.groups() try: pos_tuple = (int(from_pos), int(to_pos)) return LINKAGE_TYPES.get(pos_tuple, 7) # 7 = other except ValueError: return 8 # Unknown (contains '?') return 8 class WURCSTokenizer: """ WURCS tokenizer matching v3.1 pre-training exactly. Key features: - Bracket-counting parser for '/' handling in modifications - Branch token insertion based on topology - Whole-token modifications (58 total) """ def __init__(self, vocab_path: str): """ Initialize tokenizer with vocabulary. Args: vocab_path: Path to vocabulary.json (same as used in v3.1 pre-training) """ with open(vocab_path, 'r') as f: self.vocab = json.load(f) # Build token_to_id mapping self.token_to_id = {} # Add special tokens for token, idx in self.vocab.get('special_tokens', {}).items(): self.token_to_id[token] = idx # Add skeleton atoms for token, idx in self.vocab.get('skeleton_atoms', {}).items(): self.token_to_id[token] = idx # Add linkage atoms for token, idx in self.vocab.get('linkage_atoms', {}).items(): self.token_to_id[token] = idx # Add anomeric symbols for token, idx in self.vocab.get('anomeric_symbols', {}).items(): self.token_to_id[token] = idx # Add anomeric positions for token, idx in self.vocab.get('anomeric_positions', {}).items(): self.token_to_id[token] = idx # Add whole modifications for token, idx in self.vocab.get('whole_modifications', {}).items(): self.token_to_id[token] = idx # Build reverse mapping self.id_to_token = {idx: token for token, idx in self.token_to_id.items()} # Set special token IDs self.pad_token_id = self.token_to_id.get('[PAD]', 0) self.unk_token_id = self.token_to_id.get('[UNK]', 1) self.start_token_id = self.token_to_id.get('[START]', 2) self.end_token_id = self.token_to_id.get('[END]', 3) self.branch_open_id = self.token_to_id.get('[BRANCH_OPEN]', 5) self.branch_close_id = self.token_to_id.get('[BRANCH_CLOSE]', 6) self.mod_token_id = self.token_to_id.get('[MOD]', 9) self.link_token_id = self.token_to_id.get('[LINK]', 7) self.vocab_size = self.vocab.get('metadata', {}).get('total_vocab_size', len(self.token_to_id)) def tokenize(self, wurcs: str, max_length: int = 512) -> Dict: """ Tokenize a WURCS string. Matches v3.1 pre-training tokenization exactly: - Bracket-counting parser for '/' handling - Branch token insertion based on topology - Whole-token modifications Args: wurcs: WURCS string to tokenize max_length: Maximum sequence length (truncate if longer) Returns: Dict with: - tokens: List of token strings - token_ids: List of token IDs - residue_ids: List of residue IDs for each token (-1=special, -2=linkage) - monosaccharide_names: List of monosaccharide names for each residue - length: Number of tokens - attention_mask: 1 for real tokens, 0 for padding - is_branched: Whether the glycan has branches """ all_tokens = [] all_token_ids = [] all_residue_ids = [] # Track which residue each token belongs to all_branch_depths = [] # NEW: Branch depth for each token (0=root, 1=child, etc.) all_linkage_types = [] # NEW: Linkage type for each token (0-8) monosaccharide_names = [] # Names of monosaccharides in order is_branched = False residue_depths = {} # Will be computed from linkages # Common monosaccharide patterns based on WURCS skeleton MONO_PATTERNS = { 'a2122h': 'Glc', 'a2112h': 'Gal', 'a1221m': 'Fuc', 'a2211m': 'Rha', 'a212h': 'Xyl', 'a21d2h': 'Man', 'a2112m': 'Ara', 'a2d21h': 'Ido', 'axxxxh': 'Hex', 'Aad21122h': 'Neu5Ac', 'Aad21122h-2': 'Neu5Gc', } # Start token (residue_id = -1 for special tokens, branch_depth = 0) all_tokens.append('[START]') all_token_ids.append(self.start_token_id) all_residue_ids.append(-1) all_branch_depths.append(0) # Special tokens at depth 0 all_linkage_types.append(0) # No linkage for special tokens try: # Extract residues and linkages residues = extract_residues_from_wurcs(wurcs) linkages = extract_linkages_from_wurcs(wurcs) # Detect branching and compute residue depths branch_points, outgoing = detect_branch_points(linkages) is_branched = len(branch_points) > 0 # NEW: Compute depth of each residue in the tree residue_depths = compute_residue_depths(linkages) # Get residue order from topology sections = parse_wurcs_sections(wurcs) topology = sections[3] if len(sections) > 3 else "" residue_order = [int(x) for x in topology.split('-') if x.isdigit()] # Tokenize each residue in order residue_letters = 'abcdefghijklmnopqrstuvwxyz' current_residue_id = 0 for i, res_idx in enumerate(residue_order): if res_idx < 1 or res_idx > len(residues): continue residue = residues[res_idx - 1] res_letter = residue_letters[i] if i < len(residue_letters) else 'z' # Determine monosaccharide name from skeleton mono_name = '' parts = residue.split('_') if parts: skeleton = parts[0].split('-')[0] if '-' in parts[0] else parts[0] # Check for modifications (GlcNAc, GalNAc, etc.) has_nac = any('NCC/3=O' in p or 'NAc' in p for p in parts) has_s = any('OSO' in p for p in parts) for pattern, name in MONO_PATTERNS.items(): if skeleton.startswith(pattern) or skeleton == pattern: mono_name = name break # Adjust for modifications if has_nac and mono_name in ['Glc', 'Gal']: mono_name = mono_name + 'NAc' elif has_s: mono_name = mono_name + 'S' monosaccharide_names.append(mono_name) # Check if this is a branch point - add [BRANCH_OPEN] # Get depth for this residue res_depth = residue_depths.get(res_letter, 0) if res_letter in branch_points: all_tokens.append('[BRANCH_OPEN]') all_token_ids.append(self.branch_open_id) all_residue_ids.append(-1) # Branch tokens are special all_branch_depths.append(res_depth) all_linkage_types.append(0) # Parse residue: skeleton-anomer_ring_mods # Part 0: Skeleton and anomer (e.g., "a2122h-1b") if parts: main_part = parts[0] for char in main_part: if char in self.token_to_id: all_tokens.append(char) all_token_ids.append(self.token_to_id[char]) all_residue_ids.append(current_residue_id) all_branch_depths.append(res_depth) all_linkage_types.append(0) # Residue tokens have no linkage type elif char == '-': if '-' in self.token_to_id: all_tokens.append('-') all_token_ids.append(self.token_to_id['-']) all_residue_ids.append(current_residue_id) all_branch_depths.append(res_depth) all_linkage_types.append(0) else: all_tokens.append('[UNK]') all_token_ids.append(self.unk_token_id) all_residue_ids.append(current_residue_id) all_branch_depths.append(res_depth) all_linkage_types.append(0) # Parts 2+: Modifications (skip part 1 which is ring closure) if len(parts) > 2: for part in parts[2:]: if '*' in part: all_tokens.append('[MOD]') all_token_ids.append(self.mod_token_id) all_residue_ids.append(current_residue_id) all_branch_depths.append(res_depth) all_linkage_types.append(0) if part in self.token_to_id: all_tokens.append(part) all_token_ids.append(self.token_to_id[part]) all_residue_ids.append(current_residue_id) all_branch_depths.append(res_depth) all_linkage_types.append(0) else: all_tokens.append('[UNK_MOD]') all_token_ids.append(self.token_to_id.get('[UNK_MOD]', self.unk_token_id)) all_residue_ids.append(current_residue_id) all_branch_depths.append(res_depth) all_linkage_types.append(0) current_residue_id += 1 # Add branch close tokens at the end branch_opens = all_tokens.count('[BRANCH_OPEN]') for _ in range(branch_opens): all_tokens.append('[BRANCH_CLOSE]') all_token_ids.append(self.branch_close_id) all_residue_ids.append(-1) all_branch_depths.append(0) # BRANCH_CLOSE at depth 0 (closing structure) all_linkage_types.append(0) # Tokenize linkages (residue_id = -2 for linkage tokens) for linkage in linkages: if not linkage: continue # Parse linkage type for this linkage link_type = parse_linkage_type(linkage) all_tokens.append('[LINK]') all_token_ids.append(self.link_token_id) all_residue_ids.append(-2) all_branch_depths.append(0) # Linkages at depth 0 (structural) all_linkage_types.append(link_type) # Actual linkage type here! for char in linkage: if char in self.token_to_id: all_tokens.append(char) all_token_ids.append(self.token_to_id[char]) all_residue_ids.append(-2) all_branch_depths.append(0) all_linkage_types.append(link_type) else: all_tokens.append('[UNK]') all_token_ids.append(self.unk_token_id) all_residue_ids.append(-2) all_branch_depths.append(0) all_linkage_types.append(link_type) except Exception: pass # End token all_tokens.append('[END]') all_token_ids.append(self.end_token_id) all_residue_ids.append(-1) all_branch_depths.append(0) all_linkage_types.append(0) # Truncate if necessary if len(all_token_ids) > max_length: all_tokens = all_tokens[:max_length-1] + ['[END]'] all_token_ids = all_token_ids[:max_length-1] + [self.end_token_id] all_residue_ids = all_residue_ids[:max_length-1] + [-1] all_branch_depths = all_branch_depths[:max_length-1] + [0] all_linkage_types = all_linkage_types[:max_length-1] + [0] length = len(all_token_ids) attention_mask = [1] * length # Pad to max_length padding_length = max_length - length if padding_length > 0: all_tokens = all_tokens + ['[PAD]'] * padding_length all_token_ids = all_token_ids + [self.pad_token_id] * padding_length all_residue_ids = all_residue_ids + [-1] * padding_length all_branch_depths = all_branch_depths + [0] * padding_length all_linkage_types = all_linkage_types + [0] * padding_length attention_mask = attention_mask + [0] * padding_length return { 'tokens': all_tokens, 'token_ids': all_token_ids, 'residue_ids': all_residue_ids, 'branch_depths': all_branch_depths, # NEW: Branch depth for each token 'linkage_types': all_linkage_types, # NEW: Linkage type for each token 'monosaccharide_names': monosaccharide_names, 'num_residues': len(monosaccharide_names), 'length': length, 'attention_mask': attention_mask, 'is_branched': is_branched, 'has_unk_mod': '[UNK_MOD]' in all_tokens[:length] } def batch_tokenize(self, wurcs_list: List[str], max_length: int = 512) -> Dict: """ Tokenize a batch of WURCS strings. Args: wurcs_list: List of WURCS strings max_length: Maximum sequence length Returns: Dict with batched tensors (as lists) """ results = [self.tokenize(w, max_length) for w in wurcs_list] return { 'token_ids': [r['token_ids'] for r in results], 'attention_mask': [r['attention_mask'] for r in results], 'lengths': [r['length'] for r in results] } def decode(self, token_ids: List[int]) -> str: """ Decode token IDs back to tokens (for debugging). Args: token_ids: List of token IDs Returns: Space-separated token string """ tokens = [self.id_to_token.get(tid, '[UNK]') for tid in token_ids] # Remove padding tokens = [t for t in tokens if t != '[PAD]'] return ' '.join(tokens) def create_tokenizer(vocab_path: str = None) -> WURCSTokenizer: """ Create a tokenizer with default vocabulary path. Args: vocab_path: Optional path to vocabulary.json Returns: WURCSTokenizer instance """ if vocab_path is None: # Try default locations default_paths = [ Path(__file__).parent.parent.parent / 'data' / 'vocabulary.json', Path(__file__).parent.parent / 'data' / 'vocabulary.json', ] for p in default_paths: if p.exists(): vocab_path = str(p) break else: raise FileNotFoundError("vocabulary.json not found. Please specify path.") return WURCSTokenizer(vocab_path) if __name__ == '__main__': # Test tokenizer import sys # Find vocab vocab_path = Path(__file__).parent.parent.parent / 'data' / 'vocabulary.json' if not vocab_path.exists(): print(f"Vocabulary not found at {vocab_path}") sys.exit(1) tokenizer = WURCSTokenizer(str(vocab_path)) print(f"Loaded tokenizer with {tokenizer.vocab_size} tokens") # Test cases test_cases = [ "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1", "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1", ] for wurcs in test_cases: result = tokenizer.tokenize(wurcs, max_length=64) print(f"\nWURCS: {wurcs[:50]}...") print(f" Length: {result['length']}") print(f" Branched: {result['is_branched']}") print(f" Tokens: {result['tokens'][:15]}...") print(f" Token IDs: {result['token_ids'][:15]}...")