| |
"""
WURCS-BPE Tokenizer

A hybrid tokenizer that learns semantic subwords from WURCS while preserving
the ability to handle rare/novel glycan structures character-by-character.

Key features:
1. Pre-tokenization: Split WURCS into semantic units (residues, linkages, mods)
2. BPE: Learn subword merges from the corpus
3. Character fallback: Handle novel structures
4. Tree embeddings: Preserve branch_depth and linkage_type per token

Usage:
    # Train BPE on corpus
    tokenizer = WURCSBPETokenizer.train_from_corpus(
        wurcs_strings,
        num_merges=500,
        output_path="bpe_vocabulary.json"
    )

    # Tokenize
    result = tokenizer.tokenize(wurcs_string)
"""
|
|
import json
import logging
import pickle
import re
from collections import Counter, defaultdict, deque
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
|
|
|
|
class WURCSPreTokenizer:
    """
    Pre-tokenize WURCS into semantic units before BPE.

    WURCS format: WURCS=2.0/count/[residues]/indices/linkages

    We split into:
    - Residues: [a2122h-1b_1-5_2*NCC/3=O] -> base unit plus one 'mod' unit per
      '_'-separated suffix
    - Linkages: a4-b1 -> one unit per linkage
    - Special markers: [START], [END], [BRANCH_OPEN], [RESIDUE_ERROR]
      ([BRANCH_CLOSE] is reserved in the vocabulary but never emitted here)

    NOTE(review): residue letters are generated as chr(ord('a') + position) and
    LINKAGE_PATTERN only matches lowercase letters, so structures with more
    than 26 residues are not fully resolved -- confirm whether that matters
    for the target corpus.
    """

    RESIDUE_PATTERN = re.compile(r'\[([^\]]+)\]')
    LINKAGE_PATTERN = re.compile(r'([a-z])(\d+|\?)-([a-z])(\d+|\?)')

    # Linkage-type IDs keyed by (from_pos, to_pos). The table is symmetric.
    _LINKAGE_TYPE_IDS = {
        (1, 2): 0, (2, 1): 0,
        (1, 3): 1, (3, 1): 1,
        (1, 4): 2, (4, 1): 2,
        (1, 6): 3, (6, 1): 3,
        (2, 3): 4, (3, 2): 4,
        (2, 6): 5, (6, 2): 5,
        (3, 6): 6, (6, 3): 6,
    }
    _OTHER_LINKAGE_TYPE = 7      # numeric positions not present in the table
    _UNKNOWN_LINKAGE_TYPE = 8    # unparsable linkage or '?' position

    # Upper bound on memoized distance matrices; the cache is cleared when hit.
    _DIST_CACHE_MAX = 4096

    def __init__(self):
        self.special_tokens = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[START]': 2,
            '[END]': 3,
            '[MASK]': 4,
            '[BRANCH_OPEN]': 5,
            '[BRANCH_CLOSE]': 6,
            '[LINK]': 7,
            '[MOD]': 8,
            '[RESIDUE_ERROR]': 9,
        }
        # Memoized distance matrices keyed by (linkage section, residue count).
        self._dist_cache = {}

    @staticmethod
    def _make_unit(text: str, unit_type: str, residue_id: int = -1,
                   branch_depth: int = 0, linkage_type: int = 0) -> Dict:
        """Build one pre-token unit dict with the standard set of keys."""
        return {
            'text': text,
            'type': unit_type,
            'residue_id': residue_id,
            'branch_depth': branch_depth,
            'linkage_type': linkage_type,
        }

    def pre_tokenize(self, wurcs: str) -> List[Dict]:
        """
        Pre-tokenize WURCS into semantic units.

        Returns list of dicts with:
        - text: The unit text
        - type: 'special', 'residue', 'mod' or 'linkage'
        - residue_id: Position of the owning residue in the index section
          (-1 for special units, -2 for linkage units)
        - branch_depth: BFS depth of the owning residue in the linkage tree
        - linkage_type: Linkage type ID (0 for non-linkage units)

        The first unit ([START]) additionally carries 'distance_matrix', an
        N x N list of bond distances between the N residues of the index
        section, whenever the string could be parsed. That matrix is shared
        with an internal cache, so callers must treat it as read-only.

        Malformed input (missing 'WURCS=' prefix or fewer than four sections)
        yields [START], [RESIDUE_ERROR], [END]. Unexpected parse failures are
        logged and whatever units were produced so far are returned, closed
        with [END].
        """
        units = [self._make_unit('[START]', 'special')]

        if not wurcs.startswith('WURCS='):
            units.append(self._make_unit('[RESIDUE_ERROR]', 'special'))
            units.append(self._make_unit('[END]', 'special'))
            return units

        try:
            parts = self._split_wurcs_sections(wurcs)
            if len(parts) < 4:
                # Same error shape as the missing-prefix case, so downstream
                # code always sees [START] ... [END] and a known token.
                units.append(self._make_unit('[RESIDUE_ERROR]', 'special'))
                units.append(self._make_unit('[END]', 'special'))
                return units

            # parts[0] is the version and parts[1] the counts; neither is used.
            residue_defs = parts[2]
            indices = parts[3]
            linkages = parts[4] if len(parts) > 4 else ""

            residue_list = self.RESIDUE_PATTERN.findall(residue_defs)
            index_list = indices.split('-') if indices else []
            linkage_list = linkages.split('_') if linkages else []

            # Linkage letters address positions in the index section, so the
            # graph has one node per index entry, not per unique definition.
            num_nodes = len(index_list)
            branch_points, residue_depths, _, adj = self._analyze_tree_structure(
                linkage_list, num_residues=num_nodes
            )

            cache = getattr(self, '_dist_cache', None)
            if cache is None:
                cache = self._dist_cache = {}
            # The matrix depends on both the linkages and the node count.
            cache_key = (linkages, num_nodes)
            dist_matrix_raw = cache.get(cache_key)
            if dist_matrix_raw is None:
                if len(cache) >= self._DIST_CACHE_MAX:
                    cache.clear()
                dist_matrix_raw = self._compute_distance_matrix(adj, num_nodes)
                cache[cache_key] = dist_matrix_raw
            units[0]['distance_matrix'] = dist_matrix_raw

            first_letter = ord('a')
            for position, res_idx in enumerate(index_list):
                res_letter = chr(first_letter + position)
                depth = residue_depths.get(res_letter, 0)

                # One [BRANCH_OPEN] per extra child of this residue.
                for _ in range(branch_points.get(res_letter, 0)):
                    units.append(self._make_unit(
                        '[BRANCH_OPEN]', 'special', branch_depth=depth
                    ))

                # Indices are 1-based; anything unparsable or out of range
                # (including 0, which would otherwise wrap to the last
                # definition) resolves to an empty definition.
                try:
                    res_def_idx = int(res_idx) - 1
                except ValueError:
                    res_def_idx = -1
                if 0 <= res_def_idx < len(residue_list):
                    res_def = residue_list[res_def_idx]
                else:
                    res_def = ""

                base, *mods = res_def.split('_')
                units.append(self._make_unit(
                    base, 'residue', residue_id=position, branch_depth=depth
                ))
                for mod in mods:
                    units.append(self._make_unit(
                        mod, 'mod', residue_id=position, branch_depth=depth
                    ))

            for link in linkage_list:
                if not link:
                    continue
                units.append(self._make_unit(
                    link, 'linkage', residue_id=-2,
                    linkage_type=self._parse_linkage_type(link),
                ))
        except Exception:
            # Best effort: keep the units built so far, but leave a trace
            # instead of silently discarding the failure.
            logging.getLogger(__name__).warning(
                "Failed to fully pre-tokenize WURCS %r; returning partial units",
                wurcs, exc_info=True,
            )

        units.append(self._make_unit('[END]', 'special'))
        return units

    def _split_wurcs_sections(self, wurcs: str) -> List[str]:
        """
        Split a WURCS string on '/' while ignoring slashes inside [...].

        A leading 'WURCS=' prefix is removed first. A trailing empty section
        (string ending in '/') is dropped; empty sections elsewhere are kept.
        """
        if wurcs.startswith('WURCS='):
            wurcs = wurcs[6:]

        sections = []
        buffer = []
        bracket_depth = 0

        for char in wurcs:
            if char == '/' and bracket_depth == 0:
                sections.append(''.join(buffer))
                buffer = []
                continue
            if char == '[':
                bracket_depth += 1
            elif char == ']':
                bracket_depth -= 1
            buffer.append(char)

        if buffer:
            sections.append(''.join(buffer))

        return sections

    def _analyze_tree_structure(self, linkages: List[str], num_residues: int) -> Tuple[Dict, Dict, Dict, Dict]:
        """
        Analyze linkages to compute branch points and residue depths.

        Args:
            linkages: Linkage strings such as 'a4-b1'; entries that do not
                match LINKAGE_PATTERN are ignored.
            num_residues: Number of residue nodes; edges that reference a
                letter outside [0, num_residues) are left out of `adj`.

        Returns:
            (branch_points, depths, linkage_types, adj) where
            - branch_points maps a residue letter to (number of children - 1)
              for residues with more than one child,
            - depths maps residue letters reachable from the root to their
              BFS depth (root = smallest letter that is never a child, or 'a'),
            - linkage_types maps each matched linkage string to its type ID,
            - adj is an undirected adjacency list keyed by 0-based residue
              position.
        """
        branch_points = defaultdict(int)
        children = defaultdict(list)
        all_residues = set()
        linkage_types = {}
        adj = defaultdict(list)
        first_letter = ord('a')

        for link in linkages:
            match = self.LINKAGE_PATTERN.match(link)
            if not match:
                continue
            from_res, _, to_res, _ = match.groups()
            children[from_res].append(to_res)
            all_residues.update((from_res, to_res))
            linkage_types[link] = self._parse_linkage_type(link)

            u = ord(from_res) - first_letter
            v = ord(to_res) - first_letter
            if 0 <= u < num_residues and 0 <= v < num_residues:
                adj[u].append(v)
                adj[v].append(u)

        for res, kids in children.items():
            if len(kids) > 1:
                branch_points[res] = len(kids) - 1

        child_set = set()
        for kids in children.values():
            child_set.update(kids)
        roots = all_residues - child_set
        root = min(roots) if roots else 'a'

        # Breadth-first walk from the root along parent -> child edges.
        depths = {root: 0}
        queue = deque([root])
        while queue:
            current = queue.popleft()
            for child in children.get(current, []):
                if child not in depths:
                    depths[child] = depths[current] + 1
                    queue.append(child)

        return branch_points, depths, linkage_types, adj

    def _compute_distance_matrix(self, adj: Dict[int, List[int]], num_residues: int) -> List[List[int]]:
        """
        Compute shortest path distance (number of bonds) between all residue pairs using BFS.

        Args:
            adj: Undirected adjacency list keyed by 0-based residue position.
                Missing keys are treated as residues without neighbours, and
                the mapping is not modified.
            num_residues: Size N of the returned N x N matrix.

        Returns:
            N x N list of ints: 0 on the diagonal, the bond count between
            connected residues, and -1 for unreachable pairs. Empty list when
            num_residues is 0.
        """
        if num_residues == 0:
            return []

        dist_matrix = [[-1] * num_residues for _ in range(num_residues)]

        for source in range(num_residues):
            row = dist_matrix[source]
            row[source] = 0
            visited = {source}
            queue = deque([source])

            while queue:
                curr = queue.popleft()
                for neighbor in adj.get(curr, ()):
                    if neighbor not in visited:
                        visited.add(neighbor)
                        row[neighbor] = row[curr] + 1
                        queue.append(neighbor)

        return dist_matrix

    def _compute_distance_matrix_OLD(self, linkages: List[str], num_residues: int) -> List[List[int]]:
        """
        Legacy entry point: compute the distance matrix straight from linkage strings.

        Kept for backward compatibility. Builds the undirected adjacency list
        from `linkages` and delegates to `_compute_distance_matrix`.
        Returns a symmetric N x N matrix where N is num_residues.
        Values are integers (number of steps). 0 on diagonal. -1 if unreachable.
        """
        if num_residues == 0:
            return []

        adj = defaultdict(list)
        first_letter = ord('a')
        for link in linkages:
            match = self.LINKAGE_PATTERN.match(link)
            if not match:
                continue
            from_res_char, _, to_res_char, _ = match.groups()
            u = ord(from_res_char) - first_letter
            v = ord(to_res_char) - first_letter
            if 0 <= u < num_residues and 0 <= v < num_residues:
                adj[u].append(v)
                adj[v].append(u)

        return self._compute_distance_matrix(adj, num_residues)

    def _parse_linkage_type(self, link: str) -> int:
        """
        Parse linkage string to get type ID.

        Returns the ID from the position-pair table, 7 for numeric position
        pairs that are not in the table, and 8 when the string does not match
        LINKAGE_PATTERN or a position is '?'.
        """
        match = self.LINKAGE_PATTERN.match(link)
        if not match:
            return self._UNKNOWN_LINKAGE_TYPE

        _, from_pos, _, to_pos = match.groups()
        try:
            positions = (int(from_pos), int(to_pos))
        except ValueError:
            # '?' marks an undetermined linkage position.
            return self._UNKNOWN_LINKAGE_TYPE
        return self._LINKAGE_TYPE_IDS.get(positions, self._OTHER_LINKAGE_TYPE)
|
|
|
|
class WURCSBPETokenizer:
    """
    BPE tokenizer for WURCS with tree-aware embeddings.

    Pre-tokenizes a WURCS string with WURCSPreTokenizer, applies the learned
    BPE merges to every non-special unit, and returns per-token residue IDs,
    branch depths and linkage types plus a token-level distance matrix.
    """

    def __init__(self, vocab_path: Optional[str] = None):
        """
        Create a tokenizer.

        Args:
            vocab_path: Optional JSON vocabulary written by `save_vocab`. When
                omitted, the vocabulary holds only the special tokens.
        """
        self.pre_tokenizer = WURCSPreTokenizer()

        self.special_tokens = self.pre_tokenizer.special_tokens

        self.token_to_id: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.merges: List[Tuple[str, str]] = []

        if vocab_path:
            self.load_vocab(vocab_path)
        else:
            self.token_to_id = dict(self.special_tokens)
            self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    @staticmethod
    def _merge_pair(symbols, pair, merged):
        """Return `symbols` with each left-to-right occurrence of `pair` replaced by `merged`."""
        first, second = pair
        last = len(symbols) - 1
        result = []
        i = 0
        while i <= last:
            if i < last and symbols[i] == first and symbols[i + 1] == second:
                result.append(merged)
                i += 2
            else:
                result.append(symbols[i])
                i += 1
        return result

    @classmethod
    def train_from_corpus(
        cls,
        wurcs_strings: List[str],
        num_merges: int = 500,
        output_path: Optional[str] = None,
        min_frequency: int = 2,
        max_token_length: Optional[int] = None,
    ) -> 'WURCSBPETokenizer':
        """
        Train BPE on a corpus of WURCS strings.

        Args:
            wurcs_strings: List of WURCS strings
            num_merges: Maximum number of BPE merge operations
            output_path: Optional path to save vocabulary
            min_frequency: Minimum count for a unit to take part in merge
                learning and for a pair to be merged
            max_token_length: Maximum length of a merged token (None or 0 = no limit)

        Returns:
            Trained tokenizer
        """
        tokenizer = cls()
        pre_tok = tokenizer.pre_tokenizer

        print(f"Training BPE on {len(wurcs_strings)} WURCS strings...")

        # Count every non-special unit text across the corpus.
        unit_counts = Counter(
            unit['text']
            for wurcs in wurcs_strings
            for unit in pre_tok.pre_tokenize(wurcs)
            if unit['type'] != 'special'
        )
        print(f"Found {len(unit_counts)} unique units")

        # Every character gets an ID, even from units below min_frequency,
        # so rare units can still be encoded character by character.
        char_vocab = {char for unit in unit_counts for char in unit}

        vocab_id = max(tokenizer.token_to_id.values(), default=-1) + 1
        for char in sorted(char_vocab):
            tokenizer.token_to_id[char] = vocab_id
            tokenizer.id_to_token[vocab_id] = char
            vocab_id += 1

        print(f"Initial vocab size: {vocab_id} (special + characters)")

        # Each frequent unit starts out as a tuple of single characters.
        word_freqs = {
            tuple(unit): count
            for unit, count in unit_counts.items()
            if count >= min_frequency
        }

        merges = []

        for merge_idx in range(num_merges):
            pair_counts = Counter()
            for word, freq in word_freqs.items():
                for left, right in zip(word, word[1:]):
                    pair_counts[(left, right)] += freq

            if not pair_counts:
                break

            best_pair, best_count = pair_counts.most_common(1)[0]
            if best_count < min_frequency:
                break

            new_token = best_pair[0] + best_pair[1]

            if max_token_length and len(new_token) > max_token_length:
                # The most frequent pair is too long: fall back to the most
                # frequent pair that respects the limit and is frequent enough.
                found_valid_pair = False
                for pair, count in pair_counts.most_common():
                    token_candidate = pair[0] + pair[1]
                    if len(token_candidate) > max_token_length:
                        continue
                    if count < min_frequency:
                        break
                    best_pair = pair
                    best_count = count
                    new_token = token_candidate
                    found_valid_pair = True
                    break

                if not found_valid_pair:
                    print(f" Stopping early: No more pairs satisfy max_token_length={max_token_length}")
                    break

            merges.append(best_pair)

            # Two different merges can spell the same string. Keep the first
            # ID so that IDs stay dense and vocab_size bounds every ID.
            if new_token not in tokenizer.token_to_id:
                tokenizer.token_to_id[new_token] = vocab_id
                tokenizer.id_to_token[vocab_id] = new_token
                vocab_id += 1

            word_freqs = {
                tuple(cls._merge_pair(word, best_pair, new_token)): freq
                for word, freq in word_freqs.items()
            }

            if (merge_idx + 1) % 100 == 0:
                print(f" Merge {merge_idx + 1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}' (count={best_count})")

        tokenizer.merges = merges
        print(f"Final vocab size: {len(tokenizer.token_to_id)}")

        if output_path:
            tokenizer.save_vocab(output_path)

        return tokenizer

    def apply_bpe(self, text: str) -> List[str]:
        """
        Apply BPE merges to a text string.

        A text that is already a vocabulary entry is returned as one token.
        Otherwise the text is split into characters and the learned merges
        are applied in training order.
        """
        if text in self.token_to_id:
            return [text]

        tokens = list(text)
        for pair in self.merges:
            if len(tokens) < 2:
                # Nothing left to merge.
                break
            tokens = self._merge_pair(tokens, pair, pair[0] + pair[1])

        return tokens

    def tokenize(self, wurcs: str, max_length: int = 256) -> Dict:
        """
        Tokenize a WURCS string.

        Sequences longer than `max_length` are cut and closed with [END];
        shorter ones are padded with [PAD] up to `max_length`.

        Returns:
            Dict with:
            - tokens: List of token strings (padded)
            - token_ids: List of token IDs (padded)
            - residue_ids: List of residue IDs (padded with -1)
            - branch_depths: List of branch depths (padded with 0)
            - linkage_types: List of linkage types (padded with 0)
            - attention_mask: 1 for real tokens, 0 for padding
            - distance_matrix: length x length bond distances between the
              residues owning each real token; -1 where either token has no
              residue or the residues are not connected (not padded)
            - length: Number of real (unpadded) tokens
        """
        # Pre-tokenize once; the units also carry the residue distance matrix.
        units = self.pre_tokenizer.pre_tokenize(wurcs)
        unk_id = self.token_to_id['[UNK]']

        tokens = []
        token_ids = []
        residue_ids = []
        branch_depths = []
        linkage_types = []

        for unit in units:
            if unit['type'] == 'special':
                pieces = [unit['text']]
            else:
                pieces = self.apply_bpe(unit['text'])
            # Every sub-token inherits the structural features of its unit.
            for tok in pieces:
                tokens.append(tok)
                token_ids.append(self.token_to_id.get(tok, unk_id))
                residue_ids.append(unit['residue_id'])
                branch_depths.append(unit['branch_depth'])
                linkage_types.append(unit['linkage_type'])

        if len(tokens) > max_length:
            keep = max_length - 1
            tokens = tokens[:keep] + ['[END]']
            token_ids = token_ids[:keep] + [self.token_to_id['[END]']]
            residue_ids = residue_ids[:keep] + [-1]
            branch_depths = branch_depths[:keep] + [0]
            linkage_types = linkage_types[:keep] + [0]

        length = len(tokens)
        attention_mask = [1] * length

        pad_count = max_length - length
        if pad_count > 0:
            tokens.extend(['[PAD]'] * pad_count)
            token_ids.extend([self.token_to_id['[PAD]']] * pad_count)
            residue_ids.extend([-1] * pad_count)
            branch_depths.extend([0] * pad_count)
            linkage_types.extend([0] * pad_count)
            attention_mask.extend([0] * pad_count)

        # Expand the residue-level matrix to token level.
        dist_matrix_raw = units[0].get('distance_matrix', []) if units else []
        num_residues = len(dist_matrix_raw)
        owners = [
            r if 0 <= r < num_residues else -1
            for r in residue_ids[:length]
        ]
        distance_matrix = [
            [
                dist_matrix_raw[r_i][r_j] if r_i >= 0 and r_j >= 0 else -1
                for r_j in owners
            ]
            for r_i in owners
        ]

        return {
            'tokens': tokens,
            'token_ids': token_ids,
            'residue_ids': residue_ids,
            'branch_depths': branch_depths,
            'linkage_types': linkage_types,
            'attention_mask': attention_mask,
            'distance_matrix': distance_matrix,
            'length': length,
        }

    def save_vocab(self, path: str):
        """Save special tokens, token IDs, merges and size metadata to a JSON file."""
        data = {
            'special_tokens': self.special_tokens,
            'token_to_id': self.token_to_id,
            'merges': [list(m) for m in self.merges],
            'metadata': {
                'vocab_size': len(self.token_to_id),
                'num_merges': len(self.merges),
            }
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)
        print(f"Saved vocabulary to {path}")

    def load_vocab(self, path: str):
        """Load special tokens, token IDs and merges from a JSON file written by `save_vocab`."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.special_tokens = data['special_tokens']
        self.token_to_id = data['token_to_id']
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        # JSON stores each merge as a list; restore the tuple form.
        self.merges = [tuple(m) for m in data['merges']]

        print(f"Loaded vocabulary with {len(self.token_to_id)} tokens")

    @property
    def vocab_size(self) -> int:
        """Number of entries in the token-to-ID mapping."""
        return len(self.token_to_id)
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Manual smoke test: print the first pre-token units of two sample glycans.
    banner = "=" * 80
    print(banner)
    print("Testing WURCSPreTokenizer")
    print(banner)

    pre_tok = WURCSPreTokenizer()

    test_wurcs = [
        "WURCS=2.0/2,2,1/[a2122h-1b_1-5][a2211m-1a_1-5]/1-2/a4-b1",
        "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
    ]

    preview_limit = 10  # number of units printed per sample
    for sample in test_wurcs:
        print(f"\nWURCS: {sample[:60]}...")
        parsed = pre_tok.pre_tokenize(sample)
        total = len(parsed)
        print(f"Units ({total}):")
        for entry in parsed[:preview_limit]:
            print(f" {entry['type']:10} | res={entry['residue_id']:2} | depth={entry['branch_depth']} | {entry['text']}")
        hidden = total - preview_limit
        if hidden > 0:
            print(f" ... and {hidden} more")
|
|