#!/usr/bin/env python3
"""
Train WURCS-BPE tokenizer on the glycan corpus.

This script:
1. Loads WURCS strings from sequences.pkl
2. Trains BPE with specified number of merges
3. Saves vocabulary (default: data/bpe_vocabulary.json)
4. Writes the re-tokenized sequences to a new pickle
   (default: data/sequences_bpe.pkl); the input file is left untouched

Usage:
    python train_wurcs_bpe.py --num-merges 500
"""

import argparse
import pickle
import sys
from pathlib import Path

# Imported at module level so a missing dependency is reported before the
# (potentially long) BPE training starts rather than after it.
from tqdm import tqdm

# Add the release code root for package imports.
SCRIPT_DIR = Path(__file__).resolve().parent
CODE_DIR = SCRIPT_DIR.parent
RELEASE_ROOT = CODE_DIR.parent
sys.path.insert(0, str(CODE_DIR))

from model.wurcs_bpe_tokenizer import WURCSBPETokenizer

# Keys copied verbatim from the tokenizer result into every output entry.
_TOKENIZER_FIELDS = (
    'tokens',
    'token_ids',
    'residue_ids',
    'branch_depths',
    'linkage_types',
    'attention_mask',
    'distance_matrix',  # Topology
    'length',
)


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Train WURCS-BPE tokenizer")
    parser.add_argument("--sequences", type=str,
                        default=str(RELEASE_ROOT / "data" / "sequences.pkl"),
                        help="Path to sequences.pkl")
    parser.add_argument("--num-merges", type=int, default=500,
                        help="Number of BPE merge operations")
    parser.add_argument("--min-frequency", type=int, default=2,
                        help="Minimum frequency for BPE")
    parser.add_argument("--max-token-length", type=int, default=20,
                        help="Maximum length of a BPE token (default: 20)")
    parser.add_argument("--output-vocab", type=str,
                        default=str(RELEASE_ROOT / "data" / "bpe_vocabulary.json"),
                        help="Output path for vocabulary")
    parser.add_argument("--output-sequences", type=str,
                        default=str(RELEASE_ROOT / "data" / "sequences_bpe.pkl"),
                        help="Output path for regenerated sequences")
    parser.add_argument("--max-length", type=int, default=256,
                        help="Maximum sequence length")
    return parser.parse_args()


def _retokenize(data: dict, tokenizer, max_length: int) -> tuple[dict, list]:
    """Tokenize every entry of ``data`` with ``tokenizer``.

    Returns a ``(new_sequences, lengths)`` pair: ``new_sequences`` maps each
    original key to a dict holding the tokenizer fields plus the metadata
    carried over from the source entry, and ``lengths`` lists the ``'length'``
    value the tokenizer reported for each entry, in iteration order.
    """
    new_sequences = {}
    lengths = []

    # Everything is kept in memory; the input dict is not modified.
    for wurcs_key, entry in tqdm(data.items(), desc="Tokenizing"):
        # Entries without a 'wurcs' field fall back to the dictionary key.
        wurcs = entry.get('wurcs', wurcs_key)
        result = tokenizer.tokenize(wurcs, max_length=max_length)

        record = {field: result[field] for field in _TOKENIZER_FIELDS}
        record.update({
            'wurcs': wurcs,
            'source': entry.get('source', 'unknown'),
            'has_ms': entry.get('has_ms', False),
            'has_3d': entry.get('has_3d', False),
            'iupac_name': entry.get('iupac_name', ''),
        })
        new_sequences[wurcs_key] = record
        lengths.append(result['length'])

    return new_sequences, lengths


def _print_length_statistics(lengths: list, num_sequences: int) -> None:
    """Print average/min/max token length; tolerate an empty corpus."""
    print("\n--- Tokenization Statistics ---")
    if not lengths:
        # An empty input would otherwise raise ZeroDivisionError / ValueError.
        print("  No sequences were tokenized; skipping length statistics.")
        print(f"  Sequences: {num_sequences}")
        return
    print(f"  Average length: {sum(lengths) / len(lengths):.1f} tokens")
    print(f"  Min length: {min(lengths)}")
    print(f"  Max length: {max(lengths)}")
    print(f"  Sequences: {num_sequences}")


def main():
    """Run the full pipeline: load corpus, train BPE, re-tokenize, save."""
    args = _parse_args()
    banner = "=" * 70

    print(banner)
    print("TRAINING WURCS-BPE TOKENIZER")
    print(banner)

    # Load WURCS strings
    print(f"\n[1/4] Loading WURCS strings from {args.sequences}...")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point --sequences at files produced by a trusted pipeline.
    with open(args.sequences, 'rb') as f:
        data = pickle.load(f)

    wurcs_strings = [entry['wurcs'] for entry in data.values() if entry.get('wurcs')]
    print(f"Found {len(wurcs_strings)} WURCS strings")

    # Create the output directory up front so a bad path fails before the
    # expensive training/tokenization work instead of after it.
    Path(args.output_sequences).parent.mkdir(parents=True, exist_ok=True)

    # Train BPE
    print(f"\n[2/4] Training BPE with {args.num_merges} merges...")
    tokenizer = WURCSBPETokenizer.train_from_corpus(
        wurcs_strings=wurcs_strings,
        num_merges=args.num_merges,
        output_path=args.output_vocab,
        min_frequency=args.min_frequency,
        max_token_length=args.max_token_length,
    )

    print(f"\n[3/4] Vocabulary saved to {args.output_vocab}")
    print(f"  Vocab size: {tokenizer.vocab_size}")
    print(f"  Merges: {len(tokenizer.merges)}")

    # Regenerate sequences
    print("\n[4/4] Regenerating sequences with BPE tokenization...")
    new_sequences, total_lengths = _retokenize(data, tokenizer, args.max_length)

    # Save exactly once (the file used to be written twice to the same path).
    print(f"Tokenization complete. Saving {len(new_sequences)} items...")
    with open(args.output_sequences, 'wb') as f:
        pickle.dump(new_sequences, f)

    # Statistics
    _print_length_statistics(total_lengths, len(new_sequences))

    print("\n" + banner)
    print("DONE!")
    print(banner)
    print(f"\nVocabulary: {args.output_vocab}")
    print(f"Sequences: {args.output_sequences}")
    print("\nTo use in training:")
    print(f"  1. Update config to use vocab_size={tokenizer.vocab_size}")
    print(f"  2. Update config to use max_length={args.max_length}")
    print(f"  3. Point data loader to {args.output_sequences}")


if __name__ == "__main__":
    main()