| |
| """ |
| Train WURCS-BPE tokenizer on the glycan corpus. |
| |
| This script: |
| 1. Loads WURCS strings from sequences.pkl |
| 2. Trains BPE with specified number of merges |
| 3. Saves vocabulary to data/bpe_vocabulary.json |
| 4. Writes the re-tokenized sequences to data/sequences_bpe.pkl (override with --output-sequences) |
| |
| Usage: |
| python train_wurcs_bpe.py --num-merges 500 |
| """ |
|
|
| import argparse |
| import pickle |
| import sys |
| from pathlib import Path |
|
|
| |
# Resolve the project layout relative to this file: the directory holding this
# script, its parent (CODE_DIR), and the parent of that (RELEASE_ROOT), which
# main() uses as the base for its default data/ paths.
| SCRIPT_DIR = Path(__file__).resolve().parent |
| CODE_DIR = SCRIPT_DIR.parent |
| RELEASE_ROOT = CODE_DIR.parent |
# Put CODE_DIR first on the import path; presumably this is what lets the
# `model.wurcs_bpe_tokenizer` import that follows resolve -- must stay before it.
| sys.path.insert(0, str(CODE_DIR)) |
|
|
| from model.wurcs_bpe_tokenizer import WURCSBPETokenizer |
|
|
|
|
# Fields copied verbatim from the tokenizer result into each output record.
_TOKENIZER_FIELDS = (
    'tokens',
    'token_ids',
    'residue_ids',
    'branch_depths',
    'linkage_types',
    'attention_mask',
    'distance_matrix',
    'length',
)


def _build_arg_parser():
    """Return the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser(description="Train WURCS-BPE tokenizer")
    parser.add_argument("--sequences", type=str,
                        default=str(RELEASE_ROOT / "data" / "sequences.pkl"),
                        help="Path to sequences.pkl")
    parser.add_argument("--num-merges", type=int, default=500,
                        help="Number of BPE merge operations")
    parser.add_argument("--min-frequency", type=int, default=2,
                        help="Minimum frequency for BPE")
    parser.add_argument("--max-token-length", type=int, default=20,
                        help="Maximum length of a BPE token (default: 20)")
    parser.add_argument("--output-vocab", type=str,
                        default=str(RELEASE_ROOT / "data" / "bpe_vocabulary.json"),
                        help="Output path for vocabulary")
    parser.add_argument("--output-sequences", type=str,
                        default=str(RELEASE_ROOT / "data" / "sequences_bpe.pkl"),
                        help="Output path for regenerated sequences")
    parser.add_argument("--max-length", type=int, default=256,
                        help="Maximum sequence length")
    return parser


def _retokenize(entries, tokenizer, max_length):
    """Tokenize every corpus entry with the trained tokenizer.

    Args:
        entries: Iterable of ``(key, entry)`` pairs, where ``entry`` is a
            dict-like object supporting ``.get``.
        tokenizer: Object whose ``tokenize(wurcs, max_length=...)`` returns a
            mapping containing every key in ``_TOKENIZER_FIELDS``.
        max_length: Forwarded to ``tokenizer.tokenize``.

    Returns:
        A ``(new_sequences, lengths)`` tuple: the output record for each key
        and the list of ``result['length']`` values in iteration order.
    """
    new_sequences = {}
    lengths = []
    for key, entry in entries:
        # Entries without a 'wurcs' field fall back to the dictionary key.
        wurcs = entry.get('wurcs', key)
        result = tokenizer.tokenize(wurcs, max_length=max_length)

        record = {field: result[field] for field in _TOKENIZER_FIELDS}
        # Carry over the metadata of the source entry, with the same defaults
        # for entries that lack it.
        record['wurcs'] = wurcs
        record['source'] = entry.get('source', 'unknown')
        record['has_ms'] = entry.get('has_ms', False)
        record['has_3d'] = entry.get('has_3d', False)
        record['iupac_name'] = entry.get('iupac_name', '')

        new_sequences[key] = record
        lengths.append(result['length'])
    return new_sequences, lengths


def _print_length_stats(lengths, num_sequences):
    """Print average/min/max token length; tolerate an empty corpus."""
    print("\n--- Tokenization Statistics ---")
    if lengths:
        print(f" Average length: {sum(lengths) / len(lengths):.1f} tokens")
        print(f" Min length: {min(lengths)}")
        print(f" Max length: {max(lengths)}")
    else:
        # Nothing to average: avoid ZeroDivisionError / ValueError on min/max.
        print(" No sequences were tokenized; length statistics unavailable.")
    print(f" Sequences: {num_sequences}")


def main():
    """Train the WURCS-BPE tokenizer and write the re-tokenized corpus.

    Steps:
        1. Load the pickled corpus given by ``--sequences`` and collect the
           non-empty ``'wurcs'`` strings.
        2. Train the tokenizer via ``WURCSBPETokenizer.train_from_corpus``,
           passing ``--output-vocab`` as its ``output_path``.
        3. Report vocabulary size and number of merges.
        4. Tokenize every entry and pickle the result to
           ``--output-sequences`` (written exactly once).
    """
    # Imported before any work is done so that a missing tqdm is reported
    # immediately rather than after training has already run.
    from tqdm import tqdm

    args = _build_arg_parser().parse_args()
    banner = "=" * 70

    print(banner)
    print("TRAINING WURCS-BPE TOKENIZER")
    print(banner)

    print(f"\n[1/4] Loading WURCS strings from {args.sequences}...")
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point --sequences at files from a trusted source.
    with open(args.sequences, 'rb') as f:
        data = pickle.load(f)

    # Keep only entries that carry a non-empty WURCS string.
    candidates = (entry.get('wurcs') for entry in data.values())
    wurcs_strings = [wurcs for wurcs in candidates if wurcs]
    print(f"Found {len(wurcs_strings)} WURCS strings")

    print(f"\n[2/4] Training BPE with {args.num_merges} merges...")
    tokenizer = WURCSBPETokenizer.train_from_corpus(
        wurcs_strings=wurcs_strings,
        num_merges=args.num_merges,
        output_path=args.output_vocab,
        min_frequency=args.min_frequency,
        max_token_length=args.max_token_length,
    )

    print(f"\n[3/4] Vocabulary saved to {args.output_vocab}")
    print(f" Vocab size: {tokenizer.vocab_size}")
    print(f" Merges: {len(tokenizer.merges)}")

    print("\n[4/4] Regenerating sequences with BPE tokenization...")
    progress = tqdm(data.items(), total=len(data), desc="Tokenizing")
    new_sequences, total_lengths = _retokenize(
        progress, tokenizer, args.max_length
    )
    print(f"Tokenization complete. {len(new_sequences)} items tokenized.")

    _print_length_stats(total_lengths, len(new_sequences))

    print(f"\nSaving to {args.output_sequences}...")
    # Make sure the destination directory exists before the single write.
    Path(args.output_sequences).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output_sequences, 'wb') as f:
        pickle.dump(new_sequences, f)

    print("\n" + banner)
    print("DONE!")
    print(banner)
    print(f"\nVocabulary: {args.output_vocab}")
    print(f"Sequences: {args.output_sequences}")
    print("\nTo use in training:")
    print(f" 1. Update config to use vocab_size={tokenizer.vocab_size}")
    print(f" 2. Update config to use max_length={args.max_length}")
    print(f" 3. Point data loader to {args.output_sequences}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|