supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
23.4 kB
"""
Standalone WURCS Tokenizer for Fine-tuning
This tokenizer matches the v3.1 pre-training tokenization exactly:
- Uses bracket-counting parser (handles '/' in modifications)
- 167 tokens (including 58 whole-token modifications)
- Same special tokens: [PAD], [UNK], [START], [END], [BRANCH_OPEN], [BRANCH_CLOSE], etc.
- Detects and inserts branch tokens based on WURCS topology
Usage:
tokenizer = WURCSTokenizer('data/vocabulary.json')
result = tokenizer.tokenize(wurcs_string)
# result = {'tokens': [...], 'token_ids': [...], 'length': N}
"""
import json
import logging
import re
from collections import defaultdict, deque
from pathlib import Path
from typing import Dict, List, Optional, Tuple
def parse_wurcs_sections(wurcs: str) -> List[str]:
    """
    Split a WURCS string on the '/' characters that lie outside [...] brackets.

    A plain ``wurcs.split('/')`` is wrong because modifications inside residue
    brackets (e.g. ``2*NCC/3=O``) also contain '/'. Bracket depth is tracked
    so only top-level '/' characters act as separators.

    Args:
        wurcs: WURCS string
    Returns:
        The top-level sections, e.g. [version, counts, residues, topology,
        linkages]. Empty sections between two separators are kept; a trailing
        empty section (string ending in '/') is not.
    """
    pieces: List[str] = []
    depth = 0
    start = 0
    for pos, ch in enumerate(wurcs):
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
        elif ch == '/' and depth == 0:
            # Top-level separator: close the current section, skip the '/'.
            pieces.append(wurcs[start:pos])
            start = pos + 1
    tail = wurcs[start:]
    if tail:
        pieces.append(tail)
    return pieces
def extract_residues_from_wurcs(wurcs: str) -> List[str]:
    """
    Return the residue definitions listed in a WURCS string.

    The third top-level section (located with the bracket-aware
    ``parse_wurcs_sections``, so '/' inside modifications such as
    ``2*NCC/3=O`` or ``4*OSO/3=O/3=O`` is not treated as a separator) has the
    form ``[res1][res2]...``; the text inside each bracket pair is returned.

    Args:
        wurcs: WURCS string
    Returns:
        Residue strings without brackets (e.g. ['a2122h-1a_1-5_2*NCC/3=O', ...]),
        or [] if the string has fewer than three sections.
    """
    sections = parse_wurcs_sections(wurcs)
    if len(sections) >= 3:
        return re.findall(r'\[([^\]]+)\]', sections[2])
    return []
def extract_linkages_from_wurcs(wurcs: str) -> List[str]:
    """
    Return the linkage entries of a WURCS string.

    The fifth top-level section (found with the bracket-aware
    ``parse_wurcs_sections``) holds '_'-separated linkages such as
    ``a4-b1_b3-c2``.

    Args:
        wurcs: WURCS string
    Returns:
        Non-empty linkage strings (e.g. ['a4-b1', 'b3-c2']), or [] if the
        string has fewer than five sections.
    """
    sections = parse_wurcs_sections(wurcs)
    if len(sections) <= 4:
        return []
    # filter(None, ...) drops empty pieces produced by '__' or a leading/trailing '_'.
    return list(filter(None, sections[4].split('_')))
def detect_branch_points(linkages: List[str]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Find residues with more than one outgoing linkage.

    Each linkage is read as ``<from letter><pos>-<to letter><pos>`` (position
    may be '?'); the first residue is treated as the parent of the second.
    Strings not matching that prefix are ignored.

    Args:
        linkages: Linkage strings (e.g. ['a4-b1', 'a3-c2'])
    Returns:
        Tuple of:
        - branch_points: residue letter -> children, only for residues with >1 child
        - outgoing: residue letter -> children, for every residue with a child
    """
    outgoing: Dict[str, List[str]] = {}
    for link in linkages:
        match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link)
        if match is None:
            continue
        parent, _, child, _ = match.groups()
        outgoing.setdefault(parent, []).append(child)

    branch_points = {res: kids for res, kids in outgoing.items() if len(kids) > 1}
    return branch_points, outgoing
def compute_residue_depths(linkages: List[str]) -> Dict[str, int]:
    """
    Compute the depth of each residue in the glycan tree by BFS from the root.

    The root is the alphabetically smallest residue letter that never appears
    as a child in any parsed linkage (normally 'a'). If no such residue exists
    (every residue has a parent, or no linkage parses), 'a' is used. Residues
    not reachable from that root are absent from the result.

    Args:
        linkages: Linkage strings such as ['a4-b1', 'b3-c1']; entries that do
            not match the ``<letter><pos>-<letter><pos>`` prefix are ignored.
    Returns:
        Dict mapping residue letter to its depth (root = 0, its children = 1, ...).
    """
    children: Dict[str, List[str]] = defaultdict(list)
    all_residues = set()
    for link in linkages:
        match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link)
        if match:
            parent, _, child, _ = match.groups()
            children[parent].append(child)
            all_residues.update((parent, child))

    child_set = {kid for kids in children.values() for kid in kids}
    roots = all_residues - child_set
    root = min(roots) if roots else 'a'  # fall back to 'a' if undeterminable

    # deque.popleft is O(1); list.pop(0) shifts the whole list on every pop.
    depths = {root: 0}
    queue = deque([root])
    while queue:
        current = queue.popleft()
        for child in children.get(current, ()):
            if child not in depths:
                depths[child] = depths[current] + 1
                queue.append(child)
    return depths
# Unordered carbon-position pair -> linkage-type ID used by parse_linkage_type.
_LINKAGE_TYPE_BY_POSITIONS = {
    (1, 2): 0,
    (1, 3): 1,
    (1, 4): 2,
    (1, 6): 3,
    (2, 3): 4,
    (2, 6): 5,
    (3, 6): 6,
}


def parse_linkage_type(link: str) -> int:
    """
    Map a WURCS linkage string to a linkage-type ID from its two positions.

    Only the two carbon positions in a link such as 'a4-b1' are used, and
    their order does not matter ('a4-b1' and 'b1-a4' both give 1-4). The
    anomeric configuration (alpha/beta) is NOT considered.

    Args:
        link: Linkage string like 'a4-b1' or 'a?-b1'
    Returns:
        Linkage type ID (0-8):
        0 = 1-2, 1 = 1-3, 2 = 1-4, 3 = 1-6, 4 = 2-3, 5 = 2-6, 6 = 3-6,
        7 = any other numeric pair,
        8 = unknown (a position is '?' or the string is not a simple link)
    """
    match = re.match(r'([a-z])([?\d]+)-([a-z])([?\d]+)', link)
    if not match:
        return 8
    _, from_pos, _, to_pos = match.groups()
    try:
        positions = tuple(sorted((int(from_pos), int(to_pos))))
    except ValueError:
        return 8  # a '?' (unknown) position cannot be converted to int
    return _LINKAGE_TYPE_BY_POSITIONS.get(positions, 7)
class WURCSTokenizer:
    """
    WURCS tokenizer matching the v3.1 pre-training tokenization.

    Key features:
    - Bracket-counting parser for '/' handling in modifications
    - Branch token insertion based on topology
    - Whole-token modifications (58 total)
    """

    # Vocabulary sections merged into ``token_to_id``, in this order. A token
    # string present in several sections keeps the ID from the last one.
    _VOCAB_SECTIONS = (
        'special_tokens',
        'skeleton_atoms',
        'linkage_atoms',
        'anomeric_symbols',
        'anomeric_positions',
        'whole_modifications',
    )

    # WURCS residue letters in assignment order: a-z, then A-Z.
    _RESIDUE_LETTERS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

    # Skeleton-code prefix -> monosaccharide label. Labels only appear in
    # ``monosaccharide_names``; they never influence token IDs. The first
    # matching prefix (insertion order) wins.
    # NOTE(review): 'a21d2h'->Man and 'a2112m'->Ara look suspect; kept for
    # backward compatibility — confirm against the WURCS skeleton-code table.
    _MONO_PATTERNS = {
        'a2122h': 'Glc', 'a2112h': 'Gal', 'a1122h': 'Man', 'a1221m': 'Fuc',
        'a2211m': 'Rha', 'a212h': 'Xyl', 'a21d2h': 'Man', 'a2112m': 'Ara',
        'a2d21h': 'Ido', 'axxxxh': 'Hex', 'Aad21122h': 'Neu5Ac',
    }

    def __init__(self, vocab_path: str):
        """
        Initialize tokenizer with vocabulary.

        Args:
            vocab_path: Path to vocabulary.json (same as used in v3.1 pre-training).
                The JSON object's sections listed in ``_VOCAB_SECTIONS`` map
                token strings to integer IDs.
        """
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)

        self.token_to_id: Dict[str, int] = {}
        for section in self._VOCAB_SECTIONS:
            self.token_to_id.update(self.vocab.get(section, {}))

        self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}

        # Special token IDs (fallbacks used when the vocabulary omits them).
        self.pad_token_id = self.token_to_id.get('[PAD]', 0)
        self.unk_token_id = self.token_to_id.get('[UNK]', 1)
        self.start_token_id = self.token_to_id.get('[START]', 2)
        self.end_token_id = self.token_to_id.get('[END]', 3)
        self.branch_open_id = self.token_to_id.get('[BRANCH_OPEN]', 5)
        self.branch_close_id = self.token_to_id.get('[BRANCH_CLOSE]', 6)
        self.mod_token_id = self.token_to_id.get('[MOD]', 9)
        self.link_token_id = self.token_to_id.get('[LINK]', 7)
        self.vocab_size = self.vocab.get('metadata', {}).get('total_vocab_size', len(self.token_to_id))

    @classmethod
    def _monosaccharide_name(cls, parts: List[str]) -> str:
        """Heuristically label a residue from its '_'-split WURCS parts."""
        skeleton = parts[0].split('-')[0]
        has_nac = any('NCC/3=O' in p or 'NAc' in p for p in parts)
        has_s = any('OSO' in p for p in parts)

        name = next(
            (label for prefix, label in cls._MONO_PATTERNS.items() if skeleton.startswith(prefix)),
            '<UNK>',
        )
        # Neu5Gc shares Neu5Ac's skeleton; it differs by an N-glycolyl group.
        if name == 'Neu5Ac' and any('NCCO/3=O' in p for p in parts):
            name = 'Neu5Gc'

        if has_nac and name in ('Glc', 'Gal'):
            name += 'NAc'
        elif has_s:
            name += 'S'
        return name

    def tokenize(self, wurcs: str, max_length: int = 512) -> Dict:
        """
        Tokenize a WURCS string.

        Output layout:
        - [START]
        - then, for each residue in topology order:
          - [BRANCH_OPEN] if the residue has more than one outgoing linkage
          - one token per character of its skeleton/anomer part
          - a [MOD] marker followed by a whole-token modification (or
            [UNK_MOD]) for each modification part
        - one [BRANCH_CLOSE] per [BRANCH_OPEN]
        - [LINK] followed by one token per character, for each linkage
        - [END]

        The result is truncated to ``max_length`` (keeping a final [END]) and
        right-padded with [PAD].

        Parsing is best-effort: if an exception occurs while parsing (for
        example ``wurcs`` is not a string), it is logged as a warning and the
        tokens emitted so far are kept.

        Args:
            wurcs: WURCS string to tokenize
            max_length: Maximum sequence length (truncate if longer, pad if shorter)
        Returns:
            Dict with:
            - tokens: token strings (padded with '[PAD]')
            - token_ids: token IDs (padded with the [PAD] ID)
            - residue_ids: residue index per token (-1 = special/padding, -2 = linkage)
            - branch_depths: tree depth of the token's residue (0 for special/linkage tokens)
            - linkage_types: ``parse_linkage_type`` ID for linkage tokens, 0 otherwise
            - monosaccharide_names: heuristic name per residue (not truncated)
            - num_residues: len(monosaccharide_names)
            - length: number of real (non-padding) tokens
            - attention_mask: 1 for real tokens, 0 for padding
            - is_branched: whether any residue has more than one outgoing linkage
            - has_unk_mod: whether an unrecognized modification token was emitted
        Raises:
            ValueError: if ``max_length`` < 1.
        """
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        tokens: List[str] = []
        token_ids: List[int] = []
        residue_ids: List[int] = []
        branch_depths: List[int] = []
        linkage_types: List[int] = []
        monosaccharide_names: List[str] = []
        is_branched = False

        def emit(token: str, token_id: int, residue_id: int, depth: int = 0, link_type: int = 0) -> None:
            # Keep the five parallel per-token lists in lock-step.
            tokens.append(token)
            token_ids.append(token_id)
            residue_ids.append(residue_id)
            branch_depths.append(depth)
            linkage_types.append(link_type)

        emit('[START]', self.start_token_id, -1)

        try:
            residues = extract_residues_from_wurcs(wurcs)
            linkages = extract_linkages_from_wurcs(wurcs)
            branch_points, _ = detect_branch_points(linkages)
            is_branched = bool(branch_points)
            residue_depths = compute_residue_depths(linkages)

            sections = parse_wurcs_sections(wurcs)
            topology = sections[3] if len(sections) > 3 else ""
            # isdecimal, not isdigit: characters like '²' pass isdigit but make int() raise.
            residue_order = [int(x) for x in topology.split('-') if x.isdecimal()]

            unk_mod_id = self.token_to_id.get('[UNK_MOD]', self.unk_token_id)
            residue_id = 0
            open_branches = 0
            for position, res_idx in enumerate(residue_order):
                if not 1 <= res_idx <= len(residues):
                    continue
                parts = residues[res_idx - 1].split('_')
                # The residue letter follows topology position. Past 'Z' there is
                # no letter, so no depth or branch info is looked up.
                if position < len(self._RESIDUE_LETTERS):
                    res_letter = self._RESIDUE_LETTERS[position]
                else:
                    res_letter = ''
                res_depth = residue_depths.get(res_letter, 0)
                monosaccharide_names.append(self._monosaccharide_name(parts))

                if res_letter in branch_points:
                    emit('[BRANCH_OPEN]', self.branch_open_id, -1, res_depth)
                    open_branches += 1

                # Part 0: skeleton + anomer (e.g. "a2122h-1b"), one token per character.
                for char in parts[0]:
                    if char in self.token_to_id:
                        emit(char, self.token_to_id[char], residue_id, res_depth)
                    elif char != '-':
                        emit('[UNK]', self.unk_token_id, residue_id, res_depth)
                    # A '-' missing from the vocabulary is dropped rather than
                    # mapped to [UNK] (unlike in linkages below).

                # Parts 2+: modifications (part 1, the ring closure, is skipped).
                for part in parts[2:]:
                    if '*' in part:
                        emit('[MOD]', self.mod_token_id, residue_id, res_depth)
                    if part in self.token_to_id:
                        emit(part, self.token_to_id[part], residue_id, res_depth)
                    else:
                        emit('[UNK_MOD]', unk_mod_id, residue_id, res_depth)

                residue_id += 1

            # All branches are closed together after the residues, not nested in place.
            for _ in range(open_branches):
                emit('[BRANCH_CLOSE]', self.branch_close_id, -1)

            for linkage in linkages:
                if not linkage:
                    continue
                link_type = parse_linkage_type(linkage)
                emit('[LINK]', self.link_token_id, -2, 0, link_type)
                for char in linkage:
                    if char in self.token_to_id:
                        emit(char, self.token_to_id[char], -2, 0, link_type)
                    else:
                        emit('[UNK]', self.unk_token_id, -2, 0, link_type)
        except Exception as exc:  # best-effort: keep the partial sequence, but don't hide it
            logging.getLogger(__name__).warning(
                "Partial WURCS tokenization for %.100r: %s", wurcs, exc
            )

        emit('[END]', self.end_token_id, -1)

        # Truncate, keeping a final [END].
        if len(token_ids) > max_length:
            keep = max_length - 1
            tokens = tokens[:keep] + ['[END]']
            token_ids = token_ids[:keep] + [self.end_token_id]
            residue_ids = residue_ids[:keep] + [-1]
            branch_depths = branch_depths[:keep] + [0]
            linkage_types = linkage_types[:keep] + [0]

        length = len(token_ids)
        attention_mask = [1] * length

        # Right-pad to max_length.
        padding_length = max_length - length
        if padding_length > 0:
            tokens += ['[PAD]'] * padding_length
            token_ids += [self.pad_token_id] * padding_length
            residue_ids += [-1] * padding_length
            branch_depths += [0] * padding_length
            linkage_types += [0] * padding_length
            attention_mask += [0] * padding_length

        return {
            'tokens': tokens,
            'token_ids': token_ids,
            'residue_ids': residue_ids,
            'branch_depths': branch_depths,
            'linkage_types': linkage_types,
            'monosaccharide_names': monosaccharide_names,
            'num_residues': len(monosaccharide_names),
            'length': length,
            'attention_mask': attention_mask,
            'is_branched': is_branched,
            'has_unk_mod': '[UNK_MOD]' in tokens[:length],
        }

    def batch_tokenize(self, wurcs_list: List[str], max_length: int = 512) -> Dict:
        """
        Tokenize a batch of WURCS strings.

        Args:
            wurcs_list: List of WURCS strings
            max_length: Maximum sequence length
        Returns:
            Dict with per-example lists: 'token_ids', 'attention_mask', 'lengths'
        """
        results = [self.tokenize(w, max_length) for w in wurcs_list]
        return {
            'token_ids': [r['token_ids'] for r in results],
            'attention_mask': [r['attention_mask'] for r in results],
            'lengths': [r['length'] for r in results],
        }

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to token strings (for debugging).

        Unknown IDs become '[UNK]'; '[PAD]' tokens are removed.

        Args:
            token_ids: List of token IDs
        Returns:
            Space-separated token string
        """
        tokens = [self.id_to_token.get(tid, '[UNK]') for tid in token_ids]
        return ' '.join(t for t in tokens if t != '[PAD]')
def create_tokenizer(vocab_path: Optional[str] = None) -> WURCSTokenizer:
    """
    Create a tokenizer, locating vocabulary.json automatically if needed.

    When ``vocab_path`` is None, ``data/vocabulary.json`` is searched for
    three directory levels above this file first, then two levels above.

    Args:
        vocab_path: Optional path to vocabulary.json
    Returns:
        WURCSTokenizer instance
    Raises:
        FileNotFoundError: if no path is given and no default file exists.
    """
    if vocab_path is None:
        here = Path(__file__).parent
        candidates = (
            here.parent.parent / 'data' / 'vocabulary.json',
            here.parent / 'data' / 'vocabulary.json',
        )
        found = next((p for p in candidates if p.exists()), None)
        if found is None:
            raise FileNotFoundError("vocabulary.json not found. Please specify path.")
        vocab_path = str(found)
    return WURCSTokenizer(vocab_path)
if __name__ == '__main__':
    # Manual smoke test: tokenize two small glycans with the repository vocabulary.
    import sys

    default_vocab = Path(__file__).parent.parent.parent / 'data' / 'vocabulary.json'
    if not default_vocab.exists():
        print(f"Vocabulary not found at {default_vocab}")
        sys.exit(1)

    demo = WURCSTokenizer(str(default_vocab))
    print(f"Loaded tokenizer with {demo.vocab_size} tokens")

    demo_inputs = (
        "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1",
        "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
    )
    for text in demo_inputs:
        out = demo.tokenize(text, max_length=64)
        print(f"\nWURCS: {text[:50]}...")
        print(f" Length: {out['length']}")
        print(f" Branched: {out['is_branched']}")
        print(f" Tokens: {out['tokens'][:15]}...")
        print(f" Token IDs: {out['token_ids'][:15]}...")