bertose-iar-resolver / src /wurcs_bpe_tokenizer.py
supanthadey1's picture
Add files using upload-large-folder tool
9b82aa2 verified
#!/usr/bin/env python3
"""
WURCS-BPE Tokenizer
A hybrid tokenizer that learns semantic subwords from WURCS while preserving
the ability to handle rare/novel glycan structures character-by-character.
Key features:
1. Pre-tokenization: Split WURCS into semantic units (residues, linkages, mods)
2. BPE: Learn subword merges from corpus
3. Character fallback: Handle novel structures
4. Tree embeddings: Preserve branch_depth and linkage_type per token
Usage:
    # Train BPE on corpus
    tokenizer = WURCSBPETokenizer.train_from_corpus(
        wurcs_strings,
        num_merges=500,
        output_path="bpe_vocabulary.json"
    )

    # Tokenize
    result = tokenizer.tokenize(wurcs_string)
"""
import json
import pickle
import re
from collections import Counter, defaultdict, deque
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
class WURCSPreTokenizer:
    """
    Pre-tokenize WURCS into semantic units before BPE.

    WURCS format: WURCS=2.0/counts/[residues]/indices/linkages

    We split into:
    - Residues: [a2122h-1b_1-5_2*NCC/3=O] -> one 'residue' unit for the text
      before the first '_' plus one 'mod' unit per remaining '_'-separated part
    - Linkages: a4-b1 -> one unit per '_'-separated linkage
    - Special markers: [START], [BRANCH_OPEN], [RESIDUE_ERROR], [END], ...

    Residue instances (the entries of the index section) are labelled in
    order with the letters of RESIDUE_LETTERS ('a'..'z', then 'A'..'Z');
    linkages refer to instances by these letters.
    """

    # One match per bracketed unique-residue definition.
    RESIDUE_PATTERN = re.compile(r'\[([^\]]+)\]')
    # <letter><position>-<letter><position>; a position may be '?' (unknown).
    # Only the prefix of a linkage is matched, so trailing substituent text
    # such as '*OPO*/3O/3=O' is ignored here.
    LINKAGE_PATTERN = re.compile(r'([a-zA-Z])(\d+|\?)-([a-zA-Z])(\d+|\?)')
    # Residue-instance labels in instance order (index 0 -> 'a', 26 -> 'A').
    RESIDUE_LETTERS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    _LETTER_INDEX = dict(zip(RESIDUE_LETTERS, range(len(RESIDUE_LETTERS))))
    # Linkage class ids keyed by (smaller position, larger position). Any other
    # numeric pair maps to 7; no match or a '?' position maps to 8.
    LINKAGE_TYPE_IDS = {
        (1, 2): 0,
        (1, 3): 1,
        (1, 4): 2,
        (1, 6): 3,
        (2, 3): 4,
        (2, 6): 5,
        (3, 6): 6,
    }

    def __init__(self):
        self.special_tokens = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[START]': 2,
            '[END]': 3,
            '[MASK]': 4,
            '[BRANCH_OPEN]': 5,
            '[BRANCH_CLOSE]': 6,
            '[LINK]': 7,
            '[MOD]': 8,
            '[RESIDUE_ERROR]': 9,
        }
        # Distance matrices keyed by (linkage section, number of residue
        # instances). Unbounded: grows with the number of distinct topologies.
        self._dist_cache: Dict[Tuple[str, int], List[List[int]]] = {}

    @staticmethod
    def _make_unit(text: str, unit_type: str, residue_id: int = -1,
                   branch_depth: int = 0, linkage_type: int = 0) -> Dict:
        """Build one pre-tokenizer unit dict."""
        return {
            'text': text,
            'type': unit_type,
            'residue_id': residue_id,
            'branch_depth': branch_depth,
            'linkage_type': linkage_type,
        }

    def pre_tokenize(self, wurcs: str) -> List[Dict]:
        """
        Pre-tokenize WURCS into semantic units.

        Args:
            wurcs: A WURCS string, expected to start with 'WURCS='.

        Returns:
            List of unit dicts that always starts with '[START]' and ends with
            '[END]'. Each dict has:
            - text: the unit text
            - type: 'special', 'residue', 'mod' or 'linkage'
            - residue_id: 0-based residue-instance index; -1 for special
              markers not tied to a residue, -2 for linkages
            - branch_depth: depth of the residue in the linkage tree
            - linkage_type: id from _parse_linkage_type (0 for non-linkages)
            When the residue section is parsed, the '[START]' unit also carries
            'distance_matrix': an N x N list of bond counts between residue
            instances (N = entries in the index section). The matrix is cached
            and shared between calls, so treat it as read-only.
            Input that is not WURCS, or has fewer than 4 sections, yields
            ['[START]', '[RESIDUE_ERROR]', '[END]'].
        """
        units = [self._make_unit('[START]', 'special')]
        parts = self._split_wurcs_sections(wurcs) if wurcs.startswith('WURCS=') else []
        if len(parts) < 4:
            # Not WURCS, or too few sections to locate residues and indices.
            units.append(self._make_unit('[RESIDUE_ERROR]', 'special'))
        else:
            try:
                self._append_structure_units(units, parts)
            except Exception:
                # Best effort on malformed input: keep the units parsed so far
                # but flag the failure instead of silently truncating.
                units.append(self._make_unit('[RESIDUE_ERROR]', 'special'))
        units.append(self._make_unit('[END]', 'special'))
        return units

    def _append_structure_units(self, units: List[Dict], parts: List[str]) -> None:
        """
        Append branch, residue, modification and linkage units to ``units``.

        Args:
            units: Unit list being built; units[0] receives 'distance_matrix'.
            parts: Sections from _split_wurcs_sections (at least 4):
                [version, counts, residue definitions, indices, linkages...].
        """
        residue_list = self.RESIDUE_PATTERN.findall(parts[2])
        index_list = parts[3].split('-') if parts[3] else []
        # Linkage substituents (e.g. 'a4-b1*OPO*/3O/3=O') may contain '/', so
        # every section after the indices belongs to the linkage section.
        linkages = '/'.join(parts[4:])
        linkage_list = linkages.split('_') if linkages else []

        # Graph nodes are residue *instances* (index-section entries), not the
        # unique definitions: linkage letters and residue_ids refer to them.
        num_nodes = len(index_list)
        branch_points, residue_depths, _, adj = self._analyze_tree_structure(
            linkage_list, num_residues=num_nodes
        )

        # All-pairs BFS is the most expensive step, so cache it per topology.
        # The key includes num_nodes because the matrix size depends on it.
        # setdefault keeps instances restored without __init__ working.
        cache = self.__dict__.setdefault('_dist_cache', {})
        cache_key = (linkages, num_nodes)
        if cache_key not in cache:
            cache[cache_key] = self._compute_distance_matrix(adj, num_nodes)
        dist_matrix = cache[cache_key]

        for residue_id, res_idx in enumerate(index_list):
            res_letter = self._index_to_letter(residue_id)
            depth = residue_depths.get(res_letter, 0)

            # One marker per extra child branch hanging off this residue.
            for _ in range(branch_points.get(res_letter, 0)):
                units.append(self._make_unit('[BRANCH_OPEN]', 'special', branch_depth=depth))

            res_def = self._lookup_residue_definition(res_idx, residue_list)
            if not res_def:
                # Unresolvable index: emit an explicit marker rather than an
                # empty unit that BPE would silently drop.
                units.append(self._make_unit('[RESIDUE_ERROR]', 'special', residue_id, depth))
                continue

            base, *mods = res_def.split('_')
            units.append(self._make_unit(base, 'residue', residue_id, depth))
            units.extend(self._make_unit(mod, 'mod', residue_id, depth) for mod in mods)

        # Expose the matrix on the [START] unit for the tokenizer.
        units[0]['distance_matrix'] = dist_matrix

        for link in linkage_list:
            if link:
                units.append(self._make_unit(
                    link, 'linkage', -2, 0, self._parse_linkage_type(link)
                ))

    @staticmethod
    def _lookup_residue_definition(res_idx: str, residue_list: List[str]) -> str:
        """Return the definition for a 1-based index string, or '' if it is invalid."""
        try:
            position = int(res_idx) - 1
        except ValueError:
            return ""
        # Guard the lower bound too: index '0' must not wrap to the last entry.
        return residue_list[position] if 0 <= position < len(residue_list) else ""

    @classmethod
    def _index_to_letter(cls, index: int) -> Optional[str]:
        """Return the letter labelling a 0-based residue instance, or None past RESIDUE_LETTERS."""
        if 0 <= index < len(cls.RESIDUE_LETTERS):
            return cls.RESIDUE_LETTERS[index]
        return None

    def _split_wurcs_sections(self, wurcs: str) -> List[str]:
        """
        Split a WURCS string on '/' characters outside square brackets.

        A leading 'WURCS=' prefix is removed first. A trailing empty section
        (input ending in '/') is dropped; empty sections elsewhere are kept.
        """
        if wurcs.startswith('WURCS='):
            wurcs = wurcs[len('WURCS='):]
        sections = []
        buffer = []
        bracket_depth = 0
        for char in wurcs:
            if char == '/' and bracket_depth == 0:
                sections.append(''.join(buffer))
                buffer = []
                continue
            if char == '[':
                bracket_depth += 1
            elif char == ']':
                bracket_depth -= 1
            buffer.append(char)
        if buffer:
            sections.append(''.join(buffer))
        return sections

    def _analyze_tree_structure(self, linkages: List[str], num_residues: int) -> Tuple[Dict, Dict, Dict, Dict]:
        """
        Analyze linkages to compute branch points, depths and adjacency.

        Args:
            linkages: Individual linkage strings such as 'a4-b1'.
            num_residues: Number of residue instances; linkages naming letters
                at or beyond this index are left out of the adjacency.

        Returns:
            (branch_points, depths, linkage_types, adj):
            - branch_points: letter -> children beyond the first (only for
              residues that are the left side of more than one linkage)
            - depths: letter -> BFS depth from the root residue
            - linkage_types: linkage string -> _parse_linkage_type id
            - adj: undirected adjacency, instance index -> neighbour indices
        """
        branch_points = defaultdict(int)
        children = defaultdict(list)
        all_residues = set()
        linkage_types = {}
        for link in linkages:
            match = self.LINKAGE_PATTERN.match(link)
            if not match:
                continue
            parent, _, child, _ = match.groups()
            children[parent].append(child)
            all_residues.update((parent, child))
            linkage_types[link] = self._parse_linkage_type(link)

        adj = self._build_adjacency(linkages, num_residues)

        for res, kids in children.items():
            if len(kids) > 1:
                branch_points[res] = len(kids) - 1

        # Root = residue that never appears as a child; ties go to the lowest
        # instance index ('a' when there are no linkages at all).
        child_set = {kid for kids in children.values() for kid in kids}
        roots = all_residues - child_set
        root = min(roots, key=self._LETTER_INDEX.get) if roots else 'a'
        depths = {root: 0}
        queue = deque([root])
        while queue:
            current = queue.popleft()
            for child in children.get(current, []):
                if child not in depths:
                    depths[child] = depths[current] + 1
                    queue.append(child)
        return branch_points, depths, linkage_types, adj

    def _build_adjacency(self, linkages: List[str], num_residues: int) -> Dict[int, List[int]]:
        """Build an undirected instance-index adjacency list from linkage strings."""
        adj = defaultdict(list)
        for link in linkages:
            match = self.LINKAGE_PATTERN.match(link)
            if not match:
                continue
            u = self._LETTER_INDEX[match.group(1)]
            v = self._LETTER_INDEX[match.group(3)]
            if u < num_residues and v < num_residues:
                adj[u].append(v)
                adj[v].append(u)
        return adj

    def _compute_distance_matrix(self, adj: Dict[int, List[int]], num_residues: int) -> List[List[int]]:
        """
        Compute shortest-path distances (number of bonds) between all residue pairs.

        Runs one BFS per residue. Returns an N x N list with 0 on the diagonal
        and -1 for pairs with no connecting path; [] when N == 0.
        """
        if num_residues == 0:
            return []
        dist_matrix = []
        for source in range(num_residues):
            row = [-1] * num_residues
            row[source] = 0
            queue = deque([source])
            while queue:
                current = queue.popleft()
                for neighbor in adj.get(current, ()):
                    if row[neighbor] == -1:  # -1 doubles as "not yet visited"
                        row[neighbor] = row[current] + 1
                        queue.append(neighbor)
            dist_matrix.append(row)
        return dist_matrix

    def _compute_distance_matrix_OLD(self, linkages: List[str], num_residues: int) -> List[List[int]]:
        """
        Compute all-pairs bond distances directly from linkage strings.

        Legacy entry point kept for backward compatibility; it builds the
        adjacency with _build_adjacency and delegates to
        _compute_distance_matrix, so both paths always agree.
        """
        if num_residues == 0:
            return []
        return self._compute_distance_matrix(
            self._build_adjacency(linkages, num_residues), num_residues
        )

    def _parse_linkage_type(self, link: str) -> int:
        """
        Classify a linkage string by its pair of carbon positions.

        Returns 0-6 for the pairs in LINKAGE_TYPE_IDS (order-insensitive),
        7 for any other numeric pair, and 8 when the string does not match
        LINKAGE_PATTERN or a position is '?'.
        """
        match = self.LINKAGE_PATTERN.match(link)
        if not match:
            return 8
        _, from_pos, _, to_pos = match.groups()
        if '?' in (from_pos, to_pos):
            return 8
        key = tuple(sorted((int(from_pos), int(to_pos))))
        return self.LINKAGE_TYPE_IDS.get(key, 7)
class WURCSBPETokenizer:
    """
    BPE tokenizer for WURCS with tree-aware embeddings.

    WURCS strings are pre-tokenized into semantic units by WURCSPreTokenizer.
    Non-special units are then segmented with the learned BPE merges, and
    every resulting token inherits its unit's residue_id, branch_depth and
    linkage_type.
    """

    def __init__(self, vocab_path: Optional[str] = None):
        """
        Args:
            vocab_path: Optional JSON vocabulary written by save_vocab. When
                omitted, the vocabulary holds only the special tokens.
        """
        self.pre_tokenizer = WURCSPreTokenizer()
        # Special tokens (fixed ids, taken from the pre-tokenizer)
        self.special_tokens = self.pre_tokenizer.special_tokens
        # BPE vocabulary
        self.token_to_id: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.merges: List[Tuple[str, str]] = []
        if vocab_path:
            self.load_vocab(vocab_path)
        else:
            # Initialize with special tokens only
            self.token_to_id = dict(self.special_tokens)
            self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    @classmethod
    def train_from_corpus(
        cls,
        wurcs_strings: List[str],
        num_merges: int = 500,
        output_path: Optional[str] = None,
        min_frequency: int = 2,
        max_token_length: Optional[int] = None,
    ) -> 'WURCSBPETokenizer':
        """
        Train BPE on a corpus of WURCS strings.

        Args:
            wurcs_strings: List of WURCS strings
            num_merges: Maximum number of BPE merge operations
            output_path: Optional path to save vocabulary
            min_frequency: Minimum count for a unit to join the merge corpus
                and for a pair to be merged
            max_token_length: Maximum length of a merged token (None or 0 =
                no limit)

        Returns:
            Trained tokenizer. Token ids are contiguous from 0: special
            tokens, then sorted characters, then merged tokens.
        """
        tokenizer = cls()
        # Separate pre-tokenizer so its distance cache is discarded after training.
        pre_tok = WURCSPreTokenizer()
        print(f"Training BPE on {len(wurcs_strings)} WURCS strings...")

        # Steps 1-2: pre-tokenize and count the non-special semantic units
        unit_counts = Counter(
            unit['text']
            for wurcs in wurcs_strings
            for unit in pre_tok.pre_tokenize(wurcs)
            if unit['type'] != 'special'
        )
        print(f"Found {len(unit_counts)} unique units")

        # Step 3: base vocabulary = special tokens + every character seen,
        # including characters of rare units (character fallback).
        char_vocab = set()
        for unit in unit_counts:
            char_vocab.update(unit)
        vocab_id = len(tokenizer.special_tokens)
        for char in sorted(char_vocab):
            tokenizer.token_to_id[char] = vocab_id
            tokenizer.id_to_token[vocab_id] = char
            vocab_id += 1
        print(f"Initial vocab size: {vocab_id} (special + characters)")

        # Step 4: represent each sufficiently frequent unit as a tuple of characters
        word_freqs = {
            tuple(unit): count
            for unit, count in unit_counts.items()
            if count >= min_frequency
        }

        # Step 5: BPE merging
        merges = []
        for merge_idx in range(num_merges):
            pair_counts = Counter()
            for word, freq in word_freqs.items():
                for pair in zip(word, word[1:]):
                    pair_counts[pair] += freq
            if not pair_counts:
                break

            best_pair, best_count = pair_counts.most_common(1)[0]
            if best_count < min_frequency:
                break
            if max_token_length and len(best_pair[0]) + len(best_pair[1]) > max_token_length:
                candidate = cls._first_valid_pair(pair_counts, min_frequency, max_token_length)
                if candidate is None:
                    print(f" Stopping early: No more pairs satisfy max_token_length={max_token_length}")
                    break
                best_pair, best_count = candidate

            new_token = best_pair[0] + best_pair[1]
            merges.append(best_pair)
            # Different merge paths can build the same string ('ab'+'c' and
            # 'a'+'bc'). Reuse the existing id so ids stay contiguous and
            # vocab_size covers every id.
            if new_token not in tokenizer.token_to_id:
                tokenizer.token_to_id[new_token] = vocab_id
                tokenizer.id_to_token[vocab_id] = new_token
                vocab_id += 1

            word_freqs = {
                tuple(cls._merge_pair(word, best_pair)): freq
                for word, freq in word_freqs.items()
            }

            if (merge_idx + 1) % 100 == 0:
                print(f" Merge {merge_idx + 1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}' (count={best_count})")

        tokenizer.merges = merges
        print(f"Final vocab size: {len(tokenizer.token_to_id)}")
        if output_path:
            tokenizer.save_vocab(output_path)
        return tokenizer

    @staticmethod
    def _first_valid_pair(
        pair_counts: Counter, min_frequency: int, max_token_length: int
    ) -> Optional[Tuple[Tuple[str, str], int]]:
        """
        Return the most frequent (pair, count) whose merged length fits.

        Returns None once counts drop below min_frequency or no pair fits
        max_token_length.
        """
        for pair, count in pair_counts.most_common():
            if count < min_frequency:
                return None
            if len(pair[0]) + len(pair[1]) <= max_token_length:
                return pair, count
        return None

    @staticmethod
    def _merge_pair(symbols, pair: Tuple[str, str]) -> List[str]:
        """Fuse every non-overlapping, left-to-right occurrence of ``pair`` in ``symbols``."""
        first, second = pair[0], pair[1]
        merged = first + second
        out = []
        i = 0
        n = len(symbols)
        while i < n:
            if i + 1 < n and symbols[i] == first and symbols[i + 1] == second:
                out.append(merged)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        return out

    def apply_bpe(self, text: str) -> List[str]:
        """
        Segment ``text`` with the learned merges.

        A text that is already a vocabulary entry is returned whole. Otherwise
        it is split into characters and each merge is applied in learned order.
        An empty string that is not in the vocabulary yields [].
        """
        if text in self.token_to_id:
            return [text]
        symbols = list(text)
        for pair in self.merges:
            if len(symbols) < 2:
                break  # no further merge can apply
            symbols = self._merge_pair(symbols, pair)
        return symbols

    @staticmethod
    def _token_distance_matrix(residue_ids: List[int], residue_distances: List[List[int]]) -> List[List[int]]:
        """
        Expand a residue-level distance matrix to token level.

        Entry (i, j) is residue_distances[residue_ids[i]][residue_ids[j]], or
        -1 when either residue id is outside the matrix (special tokens,
        linkages).
        """
        num_residues = len(residue_distances)
        rows = [residue_distances[r] if 0 <= r < num_residues else None for r in residue_ids]
        matrix = []
        for row in rows:
            if row is None:
                matrix.append([-1] * len(residue_ids))
            else:
                matrix.append([row[r] if 0 <= r < num_residues else -1 for r in residue_ids])
        return matrix

    def tokenize(self, wurcs: str, max_length: int = 256) -> Dict:
        """
        Tokenize a WURCS string into fixed-length model inputs.

        Args:
            wurcs: WURCS string to tokenize.
            max_length: Length of the per-token lists. Longer sequences are
                cut to max_length - 1 tokens followed by '[END]'; shorter ones
                are padded with '[PAD]'.

        Returns:
            Dict with:
            - tokens / token_ids: token strings and ids (unknown -> [UNK] id)
            - residue_ids, branch_depths, linkage_types: per-token features
              inherited from the pre-tokenized unit
            - attention_mask: 1 for real tokens, 0 for padding
            - distance_matrix: length x length residue distances for the
              unpadded tokens (-1 where a token has no residue). It is not
              padded here, to save memory; padding is left to the caller.
            - length: number of unpadded tokens
        """
        units = self.pre_tokenizer.pre_tokenize(wurcs)
        unk_id = self.token_to_id['[UNK]']

        tokens: List[str] = []
        token_ids: List[int] = []
        residue_ids: List[int] = []
        branch_depths: List[int] = []
        linkage_types: List[int] = []
        for unit in units:
            # Special markers stay whole; everything else is BPE-segmented.
            pieces = [unit['text']] if unit['type'] == 'special' else self.apply_bpe(unit['text'])
            for tok in pieces:
                tokens.append(tok)
                token_ids.append(self.token_to_id.get(tok, unk_id))
                residue_ids.append(unit['residue_id'])
                branch_depths.append(unit['branch_depth'])
                linkage_types.append(unit['linkage_type'])

        if len(tokens) > max_length:
            keep = max_length - 1
            tokens = tokens[:keep] + ['[END]']
            token_ids = token_ids[:keep] + [self.token_to_id['[END]']]
            residue_ids = residue_ids[:keep] + [-1]
            branch_depths = branch_depths[:keep] + [0]
            linkage_types = linkage_types[:keep] + [0]

        length = len(tokens)
        # Built from the unpadded residue ids, before padding is appended.
        distance_matrix = self._token_distance_matrix(
            residue_ids, units[0].get('distance_matrix', [])
        )

        pad = max(max_length - length, 0)
        attention_mask = [1] * length + [0] * pad
        tokens += ['[PAD]'] * pad
        token_ids += [self.token_to_id['[PAD]']] * pad
        residue_ids += [-1] * pad
        branch_depths += [0] * pad
        linkage_types += [0] * pad

        return {
            'tokens': tokens,
            'token_ids': token_ids,
            'residue_ids': residue_ids,
            'branch_depths': branch_depths,
            'linkage_types': linkage_types,
            'attention_mask': attention_mask,
            'distance_matrix': distance_matrix,
            'length': length,
        }

    def save_vocab(self, path: str):
        """Save special tokens, vocabulary and merges to a JSON file."""
        data = {
            'special_tokens': self.special_tokens,
            'token_to_id': self.token_to_id,
            'merges': [list(m) for m in self.merges],
            'metadata': {
                'vocab_size': len(self.token_to_id),
                'num_merges': len(self.merges),
            }
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)
        print(f"Saved vocabulary to {path}")

    def load_vocab(self, path: str):
        """Load a vocabulary written by save_vocab; replaces tokens and merges."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.special_tokens = data['special_tokens']
        self.token_to_id = data['token_to_id']
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        # JSON stores merges as lists; restore the tuple form used at training time.
        self.merges = [tuple(m) for m in data['merges']]
        print(f"Loaded vocabulary with {len(self.token_to_id)} tokens")

    @property
    def vocab_size(self) -> int:
        """
        Number of embedding rows needed: one more than the largest token id.

        Equals len(token_to_id) for contiguous ids (as train_from_corpus
        produces), and stays safe for vocabularies saved with id gaps.
        """
        return max(self.token_to_id.values(), default=-1) + 1
# ============================================================================
# Testing
# ============================================================================
if __name__ == '__main__':
    # Smoke test for the pre-tokenizer: prints the first ten units of each
    # sample WURCS string. Nothing is asserted; inspect the output by eye.
    # Test pre-tokenizer
    print("="*80)
    print("Testing WURCSPreTokenizer")
    print("="*80)
    pre_tok = WURCSPreTokenizer()
    # Two samples: a two-residue string, and a three-residue string whose
    # first residue definition contains a '/' inside its brackets (exercises
    # the bracket-aware section splitting).
    test_wurcs = [
        "WURCS=2.0/2,2,1/[a2122h-1b_1-5][a2211m-1a_1-5]/1-2/a4-b1",
        "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
    ]
    for wurcs in test_wurcs:
        print(f"\nWURCS: {wurcs[:60]}...")
        units = pre_tok.pre_tokenize(wurcs)
        print(f"Units ({len(units)}):")
        # Show at most the first ten units, then a count of the remainder.
        for u in units[:10]:
            print(f" {u['type']:10} | res={u['residue_id']:2} | depth={u['branch_depth']} | {u['text']}")
        if len(units) > 10:
            print(f" ... and {len(units) - 10} more")