#!/usr/bin/env python3
"""
WURCS-BPE Tokenizer

A hybrid tokenizer that learns semantic subwords from WURCS while preserving
the ability to handle rare/novel glycan structures character-by-character.

Key features:
1. Pre-tokenization: Split WURCS into semantic units (residues, linkages, mods)
2. BPE: Learn subword merges from corpus
3. Character fallback: Handle novel structures
4. Tree embeddings: Preserve branch_depth and linkage_type per token

Usage:
    # Train BPE on corpus
    tokenizer = WURCSBPETokenizer.train_from_corpus(
        wurcs_strings,
        num_merges=500,
        output_path="bpe_vocabulary.json"
    )

    # Tokenize
    result = tokenizer.tokenize(wurcs_string)
"""

import json
import logging
import re
from collections import Counter, defaultdict, deque
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set
import pickle

logger = logging.getLogger(__name__)


class WURCSPreTokenizer:
    """
    Pre-tokenize WURCS into semantic units before BPE.

    WURCS format: WURCS=2.0/count/[residues]/indices/linkages

    We split into:
    - Residues: the base of each residue definition (text before the first '_')
    - Mods: every further '_'-separated piece of a residue definition
    - Linkages: a4-b1 -> one unit per linkage
    - Special markers: [START], [END], [BRANCH_OPEN], [RESIDUE_ERROR]

    NOTE(review): LINKAGE_PATTERN only recognises single lowercase residue
    letters (a-z), so linkages that reference residues beyond the 26th are not
    matched by the tree analysis.
    """

    # One match per bracketed residue definition.
    RESIDUE_PATTERN = re.compile(r'\[([^\]]+)\]')
    # <residue letter><position>-<residue letter><position>; '?' = unknown position.
    LINKAGE_PATTERN = re.compile(r'([a-z])(\d+|\?)-([a-z])(\d+|\?)')

    # Unordered carbon-position pair -> linkage type id. Pairs that are not
    # listed map to 7; linkages that cannot be parsed map to 8.
    _LINKAGE_TYPES = {
        (1, 2): 0, (2, 1): 0,
        (1, 3): 1, (3, 1): 1,
        (1, 4): 2, (4, 1): 2,
        (1, 6): 3, (6, 1): 3,
        (2, 3): 4, (3, 2): 4,
        (2, 6): 5, (6, 2): 5,
        (3, 6): 6, (6, 3): 6,
    }

    def __init__(self):
        self.special_tokens = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[START]': 2,
            '[END]': 3,
            '[MASK]': 4,
            '[BRANCH_OPEN]': 5,
            '[BRANCH_CLOSE]': 6,
            '[LINK]': 7,
            '[MOD]': 8,
            '[RESIDUE_ERROR]': 9,
        }
        # (linkage string, node count) -> residue distance matrix.
        # NOTE: unbounded; it grows with the number of distinct structures seen.
        self._dist_cache = {}

    @staticmethod
    def _special_unit(text: str, branch_depth: int = 0) -> Dict:
        """Build the unit dict for a special marker token."""
        return {
            'text': text,
            'type': 'special',
            'residue_id': -1,
            'branch_depth': branch_depth,
            'linkage_type': 0,
        }

    def pre_tokenize(self, wurcs: str) -> List[Dict]:
        """
        Pre-tokenize WURCS into semantic units.

        Returns list of dicts with:
        - text: The unit text
        - type: 'special', 'residue', 'linkage' or 'mod'
        - residue_id: Index of the residue instance this belongs to
          (-1 for special, -2 for linkage)
        - branch_depth: Depth of the residue in the linkage tree
        - linkage_type: Linkage type id (0 for non-linkage units)

        The first unit ([START]) additionally carries 'distance_matrix', the
        residue-to-residue bond distance matrix, when the structure section
        was parsed. The matrix is cached and shared between calls, so callers
        must not mutate it.

        Input that does not start with 'WURCS=' or has fewer than four
        '/'-separated sections yields [START], [RESIDUE_ERROR], [END].
        """
        units = [self._special_unit('[START]')]

        if not wurcs.startswith('WURCS='):
            units.append(self._special_unit('[RESIDUE_ERROR]'))
            units.append(self._special_unit('[END]'))
            return units

        parts = self._split_wurcs_sections(wurcs)
        if len(parts) < 4:
            # Too few sections to contain residues and indices. Report it the
            # same way as a non-WURCS string so the output is always framed by
            # [START]/[END] and only uses tokens present in the vocabulary.
            units.append(self._special_unit('[RESIDUE_ERROR]'))
            units.append(self._special_unit('[END]'))
            return units

        try:
            self._append_structure_units(parts, units)
        except Exception:
            # Best-effort: keep whatever was parsed so far, but do not hide
            # the failure completely.
            logger.warning(
                "Failed to fully parse WURCS %r; returning partial units",
                wurcs,
                exc_info=True,
            )

        units.append(self._special_unit('[END]'))
        return units

    def _append_structure_units(self, parts: List[str], units: List[Dict]) -> None:
        """Append residue, mod and linkage units for the split WURCS sections.

        parts: version, counts, residue definitions, indices[, linkages], e.g.
        2.0 / 3,4,3 / [res1][res2][res3] / 1-2-3-1 / a4-b1_b3-c1_c4-d1
        """
        residue_defs = parts[2]                       # [res1][res2]...
        indices = parts[3]                            # 1-2-3-1
        linkages = parts[4] if len(parts) > 4 else ""  # a4-b1_b3-c1

        residue_list = self.RESIDUE_PATTERN.findall(residue_defs)
        linkage_list = linkages.split('_') if linkages else []
        index_list = indices.split('-') if indices else []

        # Linkage letters and residue_id both refer to residue *instances*
        # (one per entry of the index section), not to the unique residue
        # definitions, so the graph must be sized by the instance count.
        num_nodes = len(index_list)

        branch_points, residue_depths, _linkage_types, adj = self._analyze_tree_structure(
            linkage_list, num_residues=num_nodes
        )
        dist_matrix_raw = self._cached_distance_matrix(linkages, adj, num_nodes)

        for node_idx, res_idx in enumerate(index_list):
            res_letter = chr(ord('a') + node_idx)
            depth = residue_depths.get(res_letter, 0)

            # One [BRANCH_OPEN] per extra child of a branching residue.
            for _ in range(branch_points.get(res_letter, 0)):
                units.append(self._special_unit('[BRANCH_OPEN]', branch_depth=depth))

            # Indices are 1-based; anything out of range (including 0 or a
            # negative value, which would otherwise wrap around to the end of
            # the list) resolves to an empty residue definition.
            try:
                res_def_idx = int(res_idx) - 1
            except ValueError:
                res_def_idx = -1
            res_def = residue_list[res_def_idx] if 0 <= res_def_idx < len(residue_list) else ""

            # First '_'-separated piece is the residue base, the rest are mods.
            base, *mods = res_def.split('_')
            units.append({
                'text': base,
                'type': 'residue',
                'residue_id': node_idx,
                'branch_depth': depth,
                'linkage_type': 0,
            })
            for mod in mods:
                units.append({
                    'text': mod,
                    'type': 'mod',
                    'residue_id': node_idx,
                    'branch_depth': depth,
                    'linkage_type': 0,
                })

        # Stored on the [START] unit for easy access in the tokenizer.
        units[0]['distance_matrix'] = dist_matrix_raw

        for link in linkage_list:
            if not link:
                continue
            units.append({
                'text': link,
                'type': 'linkage',
                'residue_id': -2,
                'branch_depth': 0,
                'linkage_type': self._parse_linkage_type(link),
            })

    def _cached_distance_matrix(self, linkages: str, adj: Dict[int, List[int]],
                                num_nodes: int) -> List[List[int]]:
        """Return the distance matrix for a structure, computing it at most once.

        The key includes the node count because the same linkage string can
        appear with different numbers of residue instances.
        """
        cache = getattr(self, '_dist_cache', None)
        if cache is None:
            cache = self._dist_cache = {}
        key = (linkages, num_nodes)
        if key not in cache:
            cache[key] = self._compute_distance_matrix(adj, num_nodes)
        return cache[key]

    def _split_wurcs_sections(self, wurcs: str) -> List[str]:
        """Split WURCS string into '/'-separated sections.

        A '/' inside square brackets (e.g. in a residue definition such as
        [..._2*NCC/3=O]) does not split. The leading 'WURCS=' is removed. An
        empty trailing section is dropped; empty inner sections are kept.
        """
        if wurcs.startswith('WURCS='):
            wurcs = wurcs[6:]

        sections = []
        current = []
        bracket_depth = 0
        for char in wurcs:
            if char == '/' and bracket_depth == 0:
                sections.append(''.join(current))
                current = []
                continue
            if char == '[':
                bracket_depth += 1
            elif char == ']':
                bracket_depth -= 1
            current.append(char)

        if current:
            sections.append(''.join(current))
        return sections

    def _analyze_tree_structure(self, linkages: List[str],
                                num_residues: int) -> Tuple[Dict, Dict, Dict, Dict]:
        """Analyze linkages to compute branch points and residue depths.

        Args:
            linkages: Linkage strings such as 'a4-b1'; entries that do not
                match LINKAGE_PATTERN are ignored.
            num_residues: Number of residue instances; edges that reference a
                residue index outside [0, num_residues) are left out of the
                adjacency list.

        Returns:
            (branch_points, depths, linkage_types, adj) where
            - branch_points: residue letter -> number of extra children
              (only residues with more than one child are present)
            - depths: residue letter -> BFS depth from the root
            - linkage_types: linkage string -> linkage type id
            - adj: undirected adjacency list over 0-based residue indices
        """
        children = defaultdict(list)
        all_residues = set()
        linkage_types = {}
        adj = defaultdict(list)

        for link in linkages:
            match = self.LINKAGE_PATTERN.match(link)
            if not match:
                continue
            from_res, _from_pos, to_res, _to_pos = match.groups()
            children[from_res].append(to_res)
            all_residues.update((from_res, to_res))
            linkage_types[link] = self._parse_linkage_type(link)

            u = ord(from_res) - ord('a')
            v = ord(to_res) - ord('a')
            if 0 <= u < num_residues and 0 <= v < num_residues:
                adj[u].append(v)
                adj[v].append(u)

        # A residue with k > 1 children opens k - 1 extra branches.
        branch_points = defaultdict(int)
        for res, kids in children.items():
            if len(kids) > 1:
                branch_points[res] = len(kids) - 1

        # Root = alphabetically first residue that is never a child.
        child_set = set()
        for kids in children.values():
            child_set.update(kids)
        roots = all_residues - child_set
        root = min(roots) if roots else 'a'

        depths = {root: 0}
        queue = deque([root])
        while queue:
            current = queue.popleft()
            for child in children.get(current, []):
                if child not in depths:
                    depths[child] = depths[current] + 1
                    queue.append(child)

        return branch_points, depths, linkage_types, adj

    def _compute_distance_matrix(self, adj: Dict[int, List[int]],
                                 num_residues: int) -> List[List[int]]:
        """
        Compute shortest path distance (number of bonds) between all residue
        pairs using one BFS per residue.

        Returns an N x N matrix (N = num_residues) with 0 on the diagonal and
        -1 for pairs that are not connected. Returns [] when N is 0.
        """
        if num_residues == 0:
            return []

        dist_matrix = [[-1] * num_residues for _ in range(num_residues)]
        for start in range(num_residues):
            row = dist_matrix[start]
            row[start] = 0
            queue = deque([start])
            while queue:
                curr = queue.popleft()
                # .get() keeps a defaultdict from growing and accepts plain dicts.
                for neighbor in adj.get(curr, ()):
                    if row[neighbor] == -1:
                        row[neighbor] = row[curr] + 1
                        queue.append(neighbor)
        return dist_matrix

    def _compute_distance_matrix_OLD(self, linkages: List[str],
                                     num_residues: int) -> List[List[int]]:
        """
        Compute shortest path distance (number of bonds) between all residue
        pairs, starting from raw linkage strings.

        Kept for callers that still pass linkage strings; it builds the
        undirected adjacency list and delegates to _compute_distance_matrix.

        Returns an N x N matrix where N is num_residues. Values are integers
        (number of steps), 0 on the diagonal, -1 if unreachable.
        """
        if num_residues == 0:
            return []

        adj = defaultdict(list)
        for link in linkages:
            match = self.LINKAGE_PATTERN.match(link)
            if not match:
                continue
            from_res_char, _, to_res_char, _ = match.groups()
            # Residue letters map to 0-based indices (a=0, b=1, ...).
            u = ord(from_res_char) - ord('a')
            v = ord(to_res_char) - ord('a')
            if 0 <= u < num_residues and 0 <= v < num_residues:
                adj[u].append(v)
                adj[v].append(u)

        return self._compute_distance_matrix(adj, num_residues)

    def _parse_linkage_type(self, link: str) -> int:
        """Parse linkage string to get type ID.

        Returns 0-6 for the position pairs in _LINKAGE_TYPES, 7 for any other
        numeric pair, and 8 when the linkage does not match LINKAGE_PATTERN
        or has an unknown ('?') position.
        """
        match = self.LINKAGE_PATTERN.match(link)
        if not match:
            return 8
        _, from_pos, _, to_pos = match.groups()
        try:
            pos_tuple = (int(from_pos), int(to_pos))
        except ValueError:
            return 8  # Unknown position ('?')
        return self._LINKAGE_TYPES.get(pos_tuple, 7)


class WURCSBPETokenizer:
    """
    BPE tokenizer for WURCS with tree-aware embeddings.

    Each pre-tokenized unit (residue base, modification, linkage) is split
    into BPE subwords; every subword inherits the residue id, branch depth
    and linkage type of the unit it came from.
    """

    def __init__(self, vocab_path: Optional[str] = None):
        """Create a tokenizer, optionally loading a saved vocabulary.

        Args:
            vocab_path: Path to a JSON file written by save_vocab. When
                omitted the vocabulary holds only the special tokens.
        """
        self.pre_tokenizer = WURCSPreTokenizer()

        # Special tokens (fixed)
        self.special_tokens = self.pre_tokenizer.special_tokens

        # BPE vocabulary
        self.token_to_id: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.merges: List[Tuple[str, str]] = []

        if vocab_path:
            self.load_vocab(vocab_path)
        else:
            # Initialize with special tokens only
            self.token_to_id = dict(self.special_tokens)
            self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    @staticmethod
    def _merge_pair(symbols, pair: Tuple[str, str], merged: str) -> List[str]:
        """Replace each left-to-right, non-overlapping occurrence of pair with merged."""
        left, right = pair
        out = []
        i = 0
        last = len(symbols) - 1
        while i <= last:
            if i < last and symbols[i] == left and symbols[i + 1] == right:
                out.append(merged)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        return out

    @staticmethod
    def _select_merge_pair(pair_counts: Counter, min_frequency: int,
                           max_token_length: Optional[int]):
        """Pick the most frequent pair that may still be merged.

        Returns (pair, count), or None when no pair has at least
        min_frequency occurrences while also respecting max_token_length.
        """
        if not max_token_length:
            pair, count = pair_counts.most_common(1)[0]
            return (pair, count) if count >= min_frequency else None

        for pair, count in pair_counts.most_common():
            if count < min_frequency:
                break  # Sorted by count: nothing further can qualify.
            if len(pair[0]) + len(pair[1]) <= max_token_length:
                return pair, count
        return None

    @classmethod
    def train_from_corpus(
        cls,
        wurcs_strings: List[str],
        num_merges: int = 500,
        output_path: Optional[str] = None,
        min_frequency: int = 2,
        max_token_length: Optional[int] = None,
    ) -> 'WURCSBPETokenizer':
        """
        Train BPE on a corpus of WURCS strings.

        Args:
            wurcs_strings: List of WURCS strings
            num_merges: Maximum number of BPE merge operations
            output_path: Optional path to save vocabulary
            min_frequency: Minimum corpus frequency for a unit to take part in
                merging, and minimum pair frequency for a merge to happen
            max_token_length: Maximum length of a merged token (None = no limit)

        Returns:
            Trained tokenizer
        """
        tokenizer = cls()
        pre_tok = tokenizer.pre_tokenizer

        print(f"Training BPE on {len(wurcs_strings)} WURCS strings...")

        # Steps 1-2: pre-tokenize and count the non-special semantic units.
        unit_counts = Counter(
            unit['text']
            for wurcs in wurcs_strings
            for unit in pre_tok.pre_tokenize(wurcs)
            if unit['type'] != 'special'
        )
        print(f"Found {len(unit_counts)} unique units")

        # Step 3: initial vocab = special tokens + every character seen, so
        # any unit built from seen characters can be spelled out.
        char_vocab = {char for unit in unit_counts for char in unit}
        vocab_id = len(tokenizer.special_tokens)
        for char in sorted(char_vocab):
            tokenizer.token_to_id[char] = vocab_id
            tokenizer.id_to_token[vocab_id] = char
            vocab_id += 1
        print(f"Initial vocab size: {vocab_id} (special + characters)")

        # Step 4: frequent units as character tuples.
        word_freqs = {
            tuple(unit): count
            for unit, count in unit_counts.items()
            if count >= min_frequency
        }

        # Step 5: BPE merging
        merges = []
        for merge_idx in range(num_merges):
            pair_counts = Counter()
            for word, freq in word_freqs.items():
                for pair in zip(word, word[1:]):
                    pair_counts[pair] += freq

            if not pair_counts:
                break

            selection = cls._select_merge_pair(pair_counts, min_frequency, max_token_length)
            if selection is None:
                # Report only when the length limit (not frequency alone) is
                # what stopped training.
                if pair_counts.most_common(1)[0][1] >= min_frequency:
                    print(f"  Stopping early: No more pairs satisfy max_token_length={max_token_length}")
                break

            best_pair, best_count = selection
            new_token = best_pair[0] + best_pair[1]
            merges.append(best_pair)

            # Defensive: if the merged string is already in the vocabulary,
            # keep its id rather than overwrite it, which would leave a gap
            # in the id range.
            if new_token not in tokenizer.token_to_id:
                tokenizer.token_to_id[new_token] = vocab_id
                tokenizer.id_to_token[vocab_id] = new_token
                vocab_id += 1

            word_freqs = {
                tuple(cls._merge_pair(word, best_pair, new_token)): freq
                for word, freq in word_freqs.items()
            }

            if (merge_idx + 1) % 100 == 0:
                print(f"  Merge {merge_idx + 1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}' (count={best_count})")

        tokenizer.merges = merges

        print(f"Final vocab size: {len(tokenizer.token_to_id)}")

        # Save if requested
        if output_path:
            tokenizer.save_vocab(output_path)

        return tokenizer

    def apply_bpe(self, text: str) -> List[str]:
        """Apply BPE merges to a text string.

        A text that is already a vocabulary entry is returned as a single
        token. Otherwise it is split into characters and the learned merges
        are replayed in training order.
        """
        if text in self.token_to_id:
            return [text]

        tokens = list(text)
        for pair in self.merges:
            tokens = self._merge_pair(tokens, pair, pair[0] + pair[1])
        return tokens

    def tokenize(self, wurcs: str, max_length: int = 256) -> Dict:
        """
        Tokenize a WURCS string.

        Returns:
            Dict with:
            - tokens: List of token strings, padded to max_length
            - token_ids: List of token IDs ([UNK] id for unknown tokens)
            - residue_ids: List of residue IDs
            - branch_depths: List of branch depths
            - linkage_types: List of linkage types
            - attention_mask: 1 for real tokens, 0 for padding
            - distance_matrix: length x length bond distances between the
              residues of each token pair; -1 when either token is not tied
              to a residue. Deliberately NOT padded to max_length.
            - length: Number of real (unpadded) tokens
        """
        units = self.pre_tokenizer.pre_tokenize(wurcs)
        unk_id = self.token_to_id['[UNK]']

        tokens = []
        token_ids = []
        residue_ids = []
        branch_depths = []
        linkage_types = []

        for unit in units:
            # Special tokens stay as-is; everything else is split by BPE.
            if unit['type'] == 'special':
                pieces = [unit['text']]
            else:
                pieces = self.apply_bpe(unit['text'])
            for tok in pieces:
                tokens.append(tok)
                token_ids.append(self.token_to_id.get(tok, unk_id))
                residue_ids.append(unit['residue_id'])
                branch_depths.append(unit['branch_depth'])
                linkage_types.append(unit['linkage_type'])

        # Truncate if needed, always finishing with [END].
        if len(tokens) > max_length:
            keep = max_length - 1
            tokens = tokens[:keep] + ['[END]']
            token_ids = token_ids[:keep] + [self.token_to_id['[END]']]
            residue_ids = residue_ids[:keep] + [-1]
            branch_depths = branch_depths[:keep] + [0]
            linkage_types = linkage_types[:keep] + [0]

        # Create attention mask and pad
        length = len(tokens)
        pad_count = max(0, max_length - length)
        attention_mask = [1] * length + [0] * pad_count
        tokens.extend(['[PAD]'] * pad_count)
        token_ids.extend([self.token_to_id['[PAD]']] * pad_count)
        residue_ids.extend([-1] * pad_count)
        branch_depths.extend([0] * pad_count)
        linkage_types.extend([0] * pad_count)

        # Residue-level distances come from the pre-tokenizer ([START] unit).
        dist_matrix_raw = units[0].get('distance_matrix', [])
        num_residues = len(dist_matrix_raw)

        # Expand to token level over the UNPADDED length only; padding the
        # matrix to max_length here would cost a lot of memory. Tokens not
        # tied to a residue (special, linkage) get -1.
        real_residue_ids = residue_ids[:length]
        distance_matrix = []
        for r_i in real_residue_ids:
            if 0 <= r_i < num_residues:
                raw_row = dist_matrix_raw[r_i]
                distance_matrix.append([
                    raw_row[r_j] if 0 <= r_j < num_residues else -1
                    for r_j in real_residue_ids
                ])
            else:
                distance_matrix.append([-1] * length)

        return {
            'tokens': tokens,
            'token_ids': token_ids,
            'residue_ids': residue_ids,
            'branch_depths': branch_depths,
            'linkage_types': linkage_types,
            'attention_mask': attention_mask,
            'distance_matrix': distance_matrix,
            'length': length,
        }

    def save_vocab(self, path: str):
        """Save vocabulary (special tokens, token ids, merges) to a JSON file."""
        data = {
            'special_tokens': self.special_tokens,
            'token_to_id': self.token_to_id,
            'merges': [list(m) for m in self.merges],
            'metadata': {
                'vocab_size': len(self.token_to_id),
                'num_merges': len(self.merges),
            }
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        print(f"Saved vocabulary to {path}")

    def load_vocab(self, path: str):
        """Load vocabulary from a JSON file written by save_vocab."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.special_tokens = data['special_tokens']
        self.token_to_id = data['token_to_id']
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        # JSON stores each merge as a list; restore the tuple form.
        self.merges = [tuple(m) for m in data['merges']]

        print(f"Loaded vocabulary with {len(self.token_to_id)} tokens")

    @property
    def vocab_size(self) -> int:
        """Number of entries in the token -> id table."""
        return len(self.token_to_id)


# ============================================================================
# Testing
# ============================================================================

if __name__ == '__main__':
    # Test pre-tokenizer
    print("="*80)
    print("Testing WURCSPreTokenizer")
    print("="*80)

    pre_tok = WURCSPreTokenizer()

    test_wurcs = [
        "WURCS=2.0/2,2,1/[a2122h-1b_1-5][a2211m-1a_1-5]/1-2/a4-b1",
        "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
    ]

    for wurcs in test_wurcs:
        print(f"\nWURCS: {wurcs[:60]}...")
        units = pre_tok.pre_tokenize(wurcs)
        print(f"Units ({len(units)}):")
        for u in units[:10]:
            print(f"  {u['type']:10} | res={u['residue_id']:2} | depth={u['branch_depth']} | {u['text']}")
        if len(units) > 10:
            print(f"  ... and {len(units) - 10} more")