| """ |
Standalone WURCS Tokenizer for Fine-tuning

This tokenizer matches the v3.1 pre-training tokenization exactly:
- Uses bracket-counting parser (handles '/' in modifications)
- 167 tokens (including 58 whole-token modifications)
- Same special tokens: [PAD], [UNK], [START], [END], [BRANCH_OPEN], [BRANCH_CLOSE], etc.
- Detects branch points from the WURCS linkage section and inserts branch tokens

Usage:
    tokenizer = WURCSTokenizer('data/vocabulary.json')
    result = tokenizer.tokenize(wurcs_string)
    # result = {'tokens': [...], 'token_ids': [...], 'length': N, ...}
    # The lists are padded to max_length; 'length' counts the unpadded tokens.
| """ |
|
|
import json
import logging
import re
from collections import defaultdict, deque
from pathlib import Path
from typing import Dict, List, Optional, Tuple
|
|
|
|
def parse_wurcs_sections(wurcs: str) -> List[str]:
    """
    Split a WURCS string into its '/'-separated sections.

    Only a '/' seen while the bracket depth is zero acts as a separator, so
    a '/' inside a bracketed residue definition (e.g. ``2*NCC/3=O``) stays
    part of its section. Brackets themselves are kept in the output.

    Args:
        wurcs: WURCS string

    Returns:
        List of section strings; for a complete WURCS string these are
        [version, counts, residues, topology, linkages]. Empty sections in
        the middle are preserved, but an empty trailing section is dropped.
    """
    pieces: List[str] = []
    buffer: List[str] = []
    depth = 0

    for ch in wurcs:
        # A top-level slash closes the section collected so far.
        if ch == '/' and depth == 0:
            pieces.append(''.join(buffer))
            buffer = []
            continue

        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
        buffer.append(ch)

    # Flush whatever follows the last separator (nothing if it is empty).
    if buffer:
        pieces.append(''.join(buffer))

    return pieces
|
|
|
|
def extract_residues_from_wurcs(wurcs: str) -> List[str]:
    """
    Return the unique residue definitions listed in a WURCS string.

    The string is split with the bracket-aware section parser, so residue
    modifications containing '/' (e.g. 2*NCC/3=O, 4*OSO/3=O/3=O) stay
    intact. The residues are the bracketed entries of the third section.

    Args:
        wurcs: WURCS string

    Returns:
        List of residue strings without the surrounding brackets
        (e.g., ['a2122h-1a_1-5_2*NCC/3=O', ...]); empty if the string has
        fewer than three sections.
    """
    parts = parse_wurcs_sections(wurcs)

    # Section index 2 holds the "[...][...]" residue list.
    if len(parts) >= 3:
        return re.findall(r'\[([^\]]+)\]', parts[2])
    return []
|
|
|
|
def extract_linkages_from_wurcs(wurcs: str) -> List[str]:
    """
    Return the linkage entries of a WURCS string.

    The string is split with the bracket-aware section parser; the fifth
    section is then split on '_' and empty entries are discarded.

    Args:
        wurcs: WURCS string

    Returns:
        List of linkage strings (e.g., ['a4-b1', 'b3-c2']); empty if the
        string has fewer than five sections.
    """
    parts = parse_wurcs_sections(wurcs)

    # Section index 4 is the linkage list; without it there is nothing to do.
    if len(parts) <= 4:
        return []

    return list(filter(None, parts[4].split('_')))
|
|
|
|
def detect_branch_points(linkages: List[str]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Find the residues that link to more than one other residue.

    Each linkage is expected to start with ``<letter><pos>-<letter><pos>``
    where a position is made of digits and/or '?'. The first letter is
    treated as the parent and the second as its child. Entries that do not
    match are ignored.

    Args:
        linkages: List of linkage strings (e.g., ['a4-b1', 'a3-c2'])

    Returns:
        Tuple of:
        - branch_points: Dict mapping residue letter to list of children
          (only residues with more than one child)
        - outgoing: Dict mapping every parent residue letter to its children
    """
    link_pattern = re.compile(r'([a-z])([?\d]+)-([a-z])([?\d]+)')
    children_of = defaultdict(list)

    for entry in linkages:
        hit = link_pattern.match(entry)
        if hit is None:
            continue
        # Groups 1 and 3 are the parent and child residue letters.
        children_of[hit.group(1)].append(hit.group(3))

    branch_points = {
        parent: kids for parent, kids in children_of.items() if len(kids) >= 2
    }
    return branch_points, dict(children_of)
|
|
|
|
def compute_residue_depths(linkages: List[str]) -> Dict[str, int]:
    """
    Compute the depth of each residue in the glycan tree using BFS from root.

    The root is the alphabetically smallest residue that never appears as
    the child of a linkage; if there is no such residue (no parsable
    linkages, or every residue is somebody's child) the root defaults to
    'a'. Depth 0 = root, 1 = children of root, etc.

    Args:
        linkages: List of linkage strings (e.g., ['a4-b1', 'b3-c1']).
            Entries that do not start with ``<letter><pos>-<letter><pos>``
            are ignored.

    Returns:
        Dict mapping residue letter to its depth in the tree. Residues that
        cannot be reached from the chosen root are absent from the dict.
    """
    children = defaultdict(list)
    all_residues = set()
    child_set = set()

    for link in linkages:
        match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link)
        if not match:
            continue
        from_res, _, to_res, _ = match.groups()
        children[from_res].append(to_res)
        all_residues.update((from_res, to_res))
        child_set.add(to_res)

    # A root is a residue with no incoming linkage.
    roots = all_residues - child_set
    root = min(roots) if roots else 'a'

    # Breadth-first traversal; a deque gives O(1) pops from the left,
    # unlike list.pop(0) which shifts the whole list on every dequeue.
    depths = {root: 0}
    queue = deque([root])

    while queue:
        current = queue.popleft()
        for child in children.get(current, ()):
            # The first visit wins, which also protects against cycles.
            if child not in depths:
                depths[child] = depths[current] + 1
                queue.append(child)

    return depths
|
|
|
|
def parse_linkage_type(link: str) -> int:
    """
    Map a linkage string to a linkage type ID based on its two positions.

    The linkage must start with ``<letter><pos>-<letter><pos>`` (e.g.
    'a4-b1'). Only the two position numbers are used, and their order does
    not matter: 'a4-b1' and 'a1-b4' give the same ID. Residue letters are
    ignored.

    Args:
        link: Linkage string like 'a4-b1' or 'a3-b1'

    Returns:
        Linkage type ID (0-8):
        0 = 1-2, 1 = 1-3, 2 = 1-4, 3 = 1-6, 4 = 2-3, 5 = 2-6, 6 = 3-6,
        7 = any other numeric position pair,
        8 = unknown (string does not match, or a position is not a plain
        number, e.g. '?')
    """
    # Index in this tuple == linkage type ID for the (low, high) position pair.
    canonical_pairs = ((1, 2), (1, 3), (1, 4), (1, 6), (2, 3), (2, 6), (3, 6))

    parsed = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link)
    if parsed is None:
        return 8

    try:
        pair = tuple(sorted((int(parsed.group(2)), int(parsed.group(4)))))
    except ValueError:
        # Positions such as '?' cannot be converted to integers.
        return 8

    if pair in canonical_pairs:
        return canonical_pairs.index(pair)
    return 7
|
|
|
|
class WURCSTokenizer:
    """
    WURCS tokenizer matching v3.1 pre-training exactly.

    Key features:
    - Bracket-counting parser for '/' handling in modifications
    - Branch token insertion based on topology
    - Whole-token modifications (58 total)

    Attributes:
        vocab: Raw vocabulary dict as loaded from the JSON file.
        token_to_id: Flat token -> id map merged from the vocabulary sections.
        id_to_token: Reverse map of ``token_to_id``.
        pad_token_id, unk_token_id, start_token_id, end_token_id,
        branch_open_id, branch_close_id, mod_token_id, link_token_id:
            IDs of the special tokens (fallback values are used when the
            vocabulary does not define them).
        vocab_size: ``metadata.total_vocab_size`` from the vocabulary, or the
            number of merged tokens when that field is absent.
    """

    # Vocabulary sections merged (in this order) into ``token_to_id``.
    # A token defined in several sections takes the ID of the last one.
    _VOCAB_SECTIONS = (
        'special_tokens',
        'skeleton_atoms',
        'linkage_atoms',
        'anomeric_symbols',
        'anomeric_positions',
        'whole_modifications',
    )

    # Skeleton-code prefix -> monosaccharide name. Checked in insertion
    # order; the first prefix that matches wins.
    _MONO_PATTERNS = {
        'a2122h': 'Glc', 'a2112h': 'Gal', 'a1221m': 'Fuc', 'a2211m': 'Rha',
        'a212h': 'Xyl', 'a21d2h': 'Man', 'a2112m': 'Ara', 'a2d21h': 'Ido',
        'axxxxh': 'Hex', 'Aad21122h': 'Neu5Ac', 'Aad21122h-2': 'Neu5Gc',
    }

    _RESIDUE_LETTERS = 'abcdefghijklmnopqrstuvwxyz'

    def __init__(self, vocab_path: str):
        """
        Initialize tokenizer with vocabulary.

        Args:
            vocab_path: Path to vocabulary.json (same as used in v3.1 pre-training)

        Raises:
            OSError: If the vocabulary file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Explicit UTF-8 so loading does not depend on the platform's
        # default encoding.
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)

        self.token_to_id = {}
        for section in self._VOCAB_SECTIONS:
            self.token_to_id.update(self.vocab.get(section, {}))

        self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}

        # Special-token IDs, with fixed fallbacks if the vocabulary lacks them.
        self.pad_token_id = self.token_to_id.get('[PAD]', 0)
        self.unk_token_id = self.token_to_id.get('[UNK]', 1)
        self.start_token_id = self.token_to_id.get('[START]', 2)
        self.end_token_id = self.token_to_id.get('[END]', 3)
        self.branch_open_id = self.token_to_id.get('[BRANCH_OPEN]', 5)
        self.branch_close_id = self.token_to_id.get('[BRANCH_CLOSE]', 6)
        self.mod_token_id = self.token_to_id.get('[MOD]', 9)
        self.link_token_id = self.token_to_id.get('[LINK]', 7)

        self.vocab_size = self.vocab.get('metadata', {}).get(
            'total_vocab_size', len(self.token_to_id)
        )

    @classmethod
    def _name_monosaccharide(cls, parts: List[str]) -> str:
        """Derive a display name for a residue from its '_'-separated parts."""
        # The skeleton code is everything before the first '-' of the first part.
        skeleton = parts[0].partition('-')[0]
        has_nac = any('NCC/3=O' in p or 'NAc' in p for p in parts)
        has_s = any('OSO' in p for p in parts)

        name = '<UNK>'
        for pattern, candidate in cls._MONO_PATTERNS.items():
            if skeleton.startswith(pattern):
                name = candidate
                break

        # N-acetyl is only reflected for Glc/Gal; otherwise a sulfate
        # (if present) is marked with a trailing 'S'.
        if has_nac and name in ('Glc', 'Gal'):
            name += 'NAc'
        elif has_s:
            name += 'S'
        return name

    def tokenize(self, wurcs: str, max_length: int = 512) -> Dict:
        """
        Tokenize a WURCS string.

        Matches v3.1 pre-training tokenization exactly:
        - Bracket-counting parser for '/' handling
        - Branch token insertion based on topology
        - Whole-token modifications

        Emitted sequence: [START]; then for each residue in topology order an
        optional [BRANCH_OPEN] (if the residue is a branch point), the
        characters of its first part, and a [MOD] + modification token for
        each modification; then one [BRANCH_CLOSE] per [BRANCH_OPEN]; then
        for each linkage a [LINK] followed by its characters; then [END].

        Tokenization is best-effort: if parsing fails part-way, the error is
        logged and whatever was produced so far is returned (at minimum
        [START] [END]).

        Args:
            wurcs: WURCS string to tokenize
            max_length: Maximum sequence length (truncate if longer, pad if
                shorter)

        Returns:
            Dict with:
            - tokens: List of token strings (padded with '[PAD]')
            - token_ids: List of token IDs
            - residue_ids: Residue index for each token
              (-1 = special/padding, -2 = linkage)
            - branch_depths: Tree depth of the owning residue for each token
              (0 for special, linkage and padding tokens other than
              [BRANCH_OPEN])
            - linkage_types: Linkage type ID for linkage tokens, 0 elsewhere
            - monosaccharide_names: Name for each tokenized residue
            - num_residues: Number of entries in monosaccharide_names
            - length: Number of tokens before padding
            - attention_mask: 1 for real tokens, 0 for padding
            - is_branched: Whether the glycan has branches
            - has_unk_mod: Whether an unknown modification was encountered
        """
        # One row per token: (token, token_id, residue_id, depth, link_type).
        rows = []

        def emit(token, token_id, residue_id=-1, depth=0, link_type=0):
            rows.append((token, token_id, residue_id, depth, link_type))

        monosaccharide_names = []
        is_branched = False

        emit('[START]', self.start_token_id)

        try:
            residues = extract_residues_from_wurcs(wurcs)
            linkages = extract_linkages_from_wurcs(wurcs)

            branch_points, _ = detect_branch_points(linkages)
            is_branched = len(branch_points) > 0

            residue_depths = compute_residue_depths(linkages)

            # The topology section lists, per residue position, the 1-based
            # index into the unique residue definitions.
            sections = parse_wurcs_sections(wurcs)
            topology = sections[3] if len(sections) > 3 else ""
            residue_order = [int(x) for x in topology.split('-') if x.isdigit()]

            current_residue_id = 0
            for i, res_idx in enumerate(residue_order):
                if res_idx < 1 or res_idx > len(residues):
                    continue

                residue = residues[res_idx - 1]
                # Positions beyond the 26th all fall back to 'z'.
                if i < len(self._RESIDUE_LETTERS):
                    res_letter = self._RESIDUE_LETTERS[i]
                else:
                    res_letter = 'z'

                parts = residue.split('_')
                monosaccharide_names.append(self._name_monosaccharide(parts))

                res_depth = residue_depths.get(res_letter, 0)

                if res_letter in branch_points:
                    emit('[BRANCH_OPEN]', self.branch_open_id, -1, res_depth)

                # First part (skeleton code + anomer) is tokenized per character.
                # NOTE(review): a '-' that is missing from the vocabulary is
                # skipped silently, any other unknown character becomes
                # [UNK] — confirm against the pre-training tokenizer.
                for char in parts[0]:
                    if char in self.token_to_id:
                        emit(char, self.token_to_id[char], current_residue_id, res_depth)
                    elif char != '-':
                        emit('[UNK]', self.unk_token_id, current_residue_id, res_depth)

                # parts[1] is never tokenized (presumably the ring closure
                # such as '1-5'); modifications are read from parts[2:].
                # NOTE(review): only parts containing '*' are treated as
                # modifications, others emit nothing — confirm against the
                # pre-training tokenizer.
                for part in parts[2:]:
                    if '*' not in part:
                        continue
                    emit('[MOD]', self.mod_token_id, current_residue_id, res_depth)
                    if part in self.token_to_id:
                        emit(part, self.token_to_id[part], current_residue_id, res_depth)
                    else:
                        emit(
                            '[UNK_MOD]',
                            self.token_to_id.get('[UNK_MOD]', self.unk_token_id),
                            current_residue_id,
                            res_depth,
                        )

                current_residue_id += 1

            # All branches are closed together after the last residue.
            branch_opens = sum(1 for row in rows if row[0] == '[BRANCH_OPEN]')
            for _ in range(branch_opens):
                emit('[BRANCH_CLOSE]', self.branch_close_id)

            for linkage in linkages:
                if not linkage:
                    continue

                link_type = parse_linkage_type(linkage)
                emit('[LINK]', self.link_token_id, -2, 0, link_type)

                for char in linkage:
                    if char in self.token_to_id:
                        emit(char, self.token_to_id[char], -2, 0, link_type)
                    else:
                        emit('[UNK]', self.unk_token_id, -2, 0, link_type)

        except Exception:
            # Deliberately best-effort: keep whatever was tokenized so far,
            # but do not hide the failure from whoever reads the logs.
            logging.getLogger(__name__).warning(
                "Failed to fully tokenize WURCS %r; returning partial result",
                wurcs,
                exc_info=True,
            )

        end_row = ('[END]', self.end_token_id, -1, 0, 0)
        rows.append(end_row)

        # Truncate, always keeping [END] as the final real token.
        if len(rows) > max_length:
            rows = rows[:max_length - 1] + [end_row]

        length = len(rows)
        attention_mask = [1] * length

        padding_length = max_length - length
        if padding_length > 0:
            rows = rows + [('[PAD]', self.pad_token_id, -1, 0, 0)] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        # Transpose the rows back into the five parallel lists.
        tokens, token_ids, residue_ids, branch_depths, linkage_types = (
            list(column) for column in zip(*rows)
        )

        return {
            'tokens': tokens,
            'token_ids': token_ids,
            'residue_ids': residue_ids,
            'branch_depths': branch_depths,
            'linkage_types': linkage_types,
            'monosaccharide_names': monosaccharide_names,
            'num_residues': len(monosaccharide_names),
            'length': length,
            'attention_mask': attention_mask,
            'is_branched': is_branched,
            'has_unk_mod': '[UNK_MOD]' in tokens[:length]
        }

    def batch_tokenize(self, wurcs_list: List[str], max_length: int = 512) -> Dict:
        """
        Tokenize a batch of WURCS strings.

        Args:
            wurcs_list: List of WURCS strings
            max_length: Maximum sequence length

        Returns:
            Dict with 'token_ids', 'attention_mask' and 'lengths', each a
            list holding one entry per input string (plain Python lists,
            not tensors).
        """
        results = [self.tokenize(w, max_length) for w in wurcs_list]

        return {
            'token_ids': [r['token_ids'] for r in results],
            'attention_mask': [r['attention_mask'] for r in results],
            'lengths': [r['length'] for r in results]
        }

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to tokens (for debugging).

        IDs missing from the vocabulary are rendered as '[UNK]' and
        '[PAD]' tokens are dropped.

        Args:
            token_ids: List of token IDs

        Returns:
            Space-separated token string
        """
        tokens = [self.id_to_token.get(tid, '[UNK]') for tid in token_ids]
        return ' '.join(t for t in tokens if t != '[PAD]')
|
|
|
|
def create_tokenizer(vocab_path: Optional[str] = None) -> WURCSTokenizer:
    """
    Create a tokenizer, locating the vocabulary file if no path is given.

    When ``vocab_path`` is None, ``data/vocabulary.json`` is looked up
    relative to this file, first three directory levels up and then two
    levels up; the first existing candidate is used.

    Args:
        vocab_path: Optional path to vocabulary.json

    Returns:
        WURCSTokenizer instance

    Raises:
        FileNotFoundError: If no path was given and none of the default
            locations exists.
    """
    if vocab_path is None:
        here = Path(__file__)
        candidates = (
            here.parent.parent.parent / 'data' / 'vocabulary.json',
            here.parent.parent / 'data' / 'vocabulary.json',
        )
        found = next((p for p in candidates if p.exists()), None)
        if found is None:
            raise FileNotFoundError("vocabulary.json not found. Please specify path.")
        vocab_path = str(found)

    return WURCSTokenizer(vocab_path)
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: load the repository vocabulary and print the
    # tokenization of two sample glycans.
    import sys

    data_dir = Path(__file__).parent.parent.parent / 'data'
    vocab_path = data_dir / 'vocabulary.json'
    if not vocab_path.exists():
        print(f"Vocabulary not found at {vocab_path}")
        sys.exit(1)

    tokenizer = WURCSTokenizer(str(vocab_path))
    print(f"Loaded tokenizer with {tokenizer.vocab_size} tokens")

    # One small disaccharide and one trisaccharide carrying a modification
    # that contains '/'.
    sample_glycans = (
        "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1",
        "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
    )

    for wurcs in sample_glycans:
        result = tokenizer.tokenize(wurcs, max_length=64)
        print(f"\nWURCS: {wurcs[:50]}...")
        print(f" Length: {result['length']}")
        print(f" Branched: {result['is_branched']}")
        print(f" Tokens: {result['tokens'][:15]}...")
        print(f" Token IDs: {result['token_ids'][:15]}...")
|
|