supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
5.57 kB
#!/usr/bin/env python3
"""
Train WURCS-BPE tokenizer on the glycan corpus.
This script:
1. Loads WURCS strings from sequences.pkl
2. Trains BPE with specified number of merges
3. Saves vocabulary to data/bpe_vocabulary.json
4. Re-tokenizes every entry with the trained tokenizer and writes the result
to data/sequences_bpe.pkl (the input sequences.pkl is left untouched)
Usage:
python train_wurcs_bpe.py --num-merges 500
"""
import argparse
import pickle
import sys
from pathlib import Path
# Add the release code root for package imports.
# Directory layout is derived purely from this file's location:
#   SCRIPT_DIR   -> directory containing this script
#   CODE_DIR     -> its parent; put first on sys.path below
#   RELEASE_ROOT -> parent of CODE_DIR; the default data/ paths in main() are built from it
SCRIPT_DIR = Path(__file__).resolve().parent
CODE_DIR = SCRIPT_DIR.parent
RELEASE_ROOT = CODE_DIR.parent
sys.path.insert(0, str(CODE_DIR))
# Keep this import after the sys.path.insert above; the `model` package is
# presumably located under CODE_DIR (verify against the release layout).
from model.wurcs_bpe_tokenizer import WURCSBPETokenizer
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Train WURCS-BPE tokenizer")
    parser.add_argument("--sequences", type=str,
                        default=str(RELEASE_ROOT / "data" / "sequences.pkl"),
                        help="Path to sequences.pkl")
    parser.add_argument("--num-merges", type=int, default=500,
                        help="Number of BPE merge operations")
    parser.add_argument("--min-frequency", type=int, default=2,
                        help="Minimum frequency for BPE")
    parser.add_argument("--max-token-length", type=int, default=20,
                        help="Maximum length of a BPE token (default: 20)")
    parser.add_argument("--output-vocab", type=str,
                        default=str(RELEASE_ROOT / "data" / "bpe_vocabulary.json"),
                        help="Output path for vocabulary")
    parser.add_argument("--output-sequences", type=str,
                        default=str(RELEASE_ROOT / "data" / "sequences_bpe.pkl"),
                        help="Output path for regenerated sequences")
    parser.add_argument("--max-length", type=int, default=256,
                        help="Maximum sequence length")
    return parser.parse_args()


def _retokenize(entries, tokenizer, max_length):
    """Tokenize each ``(key, entry)`` pair and return the new entries by key.

    Args:
        entries: Iterable of ``(key, entry)`` pairs, where ``entry`` is a
            dict-like object supporting ``.get``.
        tokenizer: Object whose ``tokenize(wurcs, max_length=...)`` returns a
            mapping containing the tokenizer fields copied below.
        max_length: Forwarded to ``tokenizer.tokenize``.

    Returns:
        Dict mapping each key to a new entry that combines the tokenizer
        output with metadata carried over from the input entry.
    """
    new_sequences = {}
    for wurcs_key, entry in entries:
        # Entries without a 'wurcs' field are tokenized from their key.
        wurcs = entry.get('wurcs', wurcs_key)
        result = tokenizer.tokenize(wurcs, max_length=max_length)
        new_sequences[wurcs_key] = {
            'tokens': result['tokens'],
            'token_ids': result['token_ids'],
            'residue_ids': result['residue_ids'],
            'branch_depths': result['branch_depths'],
            'linkage_types': result['linkage_types'],
            'attention_mask': result['attention_mask'],
            'distance_matrix': result['distance_matrix'],  # Topology
            'length': result['length'],
            'wurcs': wurcs,
            'source': entry.get('source', 'unknown'),
            'has_ms': entry.get('has_ms', False),
            'has_3d': entry.get('has_3d', False),
            'iupac_name': entry.get('iupac_name', ''),
        }
    return new_sequences


def _print_length_stats(lengths):
    """Print average/min/max token length; tolerate an empty list."""
    print("\n--- Tokenization Statistics ---")
    if lengths:
        print(f" Average length: {sum(lengths) / len(lengths):.1f} tokens")
        print(f" Min length: {min(lengths)}")
        print(f" Max length: {max(lengths)}")
    else:
        # An empty input would otherwise raise ZeroDivisionError / ValueError.
        print(" No sequences were tokenized; length statistics unavailable.")
    print(f" Sequences: {len(lengths)}")


def main():
    """Train a WURCS-BPE tokenizer and re-tokenize the corpus with it.

    Steps:
      1. Load the pickled corpus given by ``--sequences`` (a dict whose values
         are dict-like entries) and collect every non-empty ``'wurcs'`` string.
      2. Train the tokenizer via ``WURCSBPETokenizer.train_from_corpus``,
         passing ``--output-vocab`` as its ``output_path``.
      3. Tokenize every corpus entry with the trained tokenizer.
      4. Pickle the regenerated entries to ``--output-sequences`` (written
         exactly once) and print length statistics.

    All regenerated entries are held in memory and written in a single dump,
    so memory use grows with corpus size.
    """
    # Imported before any work is done so that a missing dependency fails
    # immediately rather than after the tokenizer has been trained.
    from tqdm import tqdm

    args = _parse_args()
    banner = "=" * 70

    print(banner)
    print("TRAINING WURCS-BPE TOKENIZER")
    print(banner)

    # Load WURCS strings
    print(f"\n[1/4] Loading WURCS strings from {args.sequences}...")
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # --sequences at files from a trusted source.
    with open(args.sequences, 'rb') as f:
        data = pickle.load(f)
    wurcs_strings = [entry['wurcs'] for entry in data.values() if entry.get('wurcs')]
    print(f"Found {len(wurcs_strings)} WURCS strings")

    # Train BPE
    print(f"\n[2/4] Training BPE with {args.num_merges} merges...")
    tokenizer = WURCSBPETokenizer.train_from_corpus(
        wurcs_strings=wurcs_strings,
        num_merges=args.num_merges,
        output_path=args.output_vocab,
        min_frequency=args.min_frequency,
        max_token_length=args.max_token_length,
    )
    print(f"\n[3/4] Vocabulary saved to {args.output_vocab}")
    print(f" Vocab size: {tokenizer.vocab_size}")
    print(f" Merges: {len(tokenizer.merges)}")

    # Regenerate sequences
    print("\n[4/4] Regenerating sequences with BPE tokenization...")
    progress = tqdm(data.items(), total=len(data), desc="Tokenizing")
    new_sequences = _retokenize(progress, tokenizer, args.max_length)

    print(f"Tokenization complete. Saving {len(new_sequences)} items...")
    # Make sure the destination directory exists so a long run is not lost
    # at the final write, then save once.
    Path(args.output_sequences).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output_sequences, 'wb') as f:
        pickle.dump(new_sequences, f)

    # Statistics
    _print_length_stats([seq['length'] for seq in new_sequences.values()])

    print("\n" + banner)
    print("DONE!")
    print(banner)
    print(f"\nVocabulary: {args.output_vocab}")
    print(f"Sequences: {args.output_sequences}")
    print("\nTo use in training:")
    print(f" 1. Update config to use vocab_size={tokenizer.vocab_size}")
    print(f" 2. Update config to use max_length={args.max_length}")
    print(f" 3. Point data loader to {args.output_sequences}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()