# Source: bertose-affinose-training-code / code/data_processing/augment_pretraining_data.py
# Author: supanthadey1 — commit 1d6f391 (verified) "Add BERTose and AFFINose training code release" (5.56 kB)
#!/usr/bin/env python3
"""
Augment pre-training data with missing GlycanML benchmark glycans.
This script adds ~5,971 glycans from the benchmark that are not in the
pre-training data (sequences.pkl), improving coverage from 39% to 100%.
Usage:
python augment_pretraining_data.py
"""
import os
import sys
import pickle
import pandas as pd
from tqdm import tqdm
# Add tokenizer path
sys.path.insert(0, 'bert_training_v4')
from downstream_tasks.utils.tokenizer import WURCSTokenizer
def _or_unknown(value):
    """Return *value*, or ``'Unknown'`` when it is missing (None/NaN)."""
    # row.get(col, default) only falls back when the column is absent; an
    # empty CSV cell in an existing column comes back as NaN, so handle both.
    if pd.isna(value):
        return 'Unknown'
    return value


def _format_coverage(covered, total):
    """Format ``covered/total`` as e.g. ``'39.0% (3800/9771)'``; ``'n/a'`` if total is 0."""
    if total == 0:
        return "n/a"
    return f"{covered / total:.1%} ({covered}/{total})"


def main():
    """Augment the pre-training data with benchmark glycans it is missing.

    Steps:
      1. Load the GlycanML benchmark CSV, the multimodal index CSV and
         ``sequences.pkl``.
      2. Collect the unique benchmark WURCS strings (first occurrence wins)
         that do not appear in the index's ``wurcs`` column.
      3. Tokenize each one with ``WURCSTokenizer``. Build a ``sequences.pkl``
         entry and a matching index row under a ``benchmark_NNNNN`` key.
         Glycans that fail to tokenize are reported and skipped.
      4. Merge the new entries with the existing data.
      5. Write ``*_augmented`` copies. The original files are not modified.

    Raises:
        ValueError: if a generated ``benchmark_NNNNN`` key already exists in
            ``sequences.pkl``. Merging would silently overwrite that entry.
    """
    # Paths (relative to the current working directory)
    BENCHMARK_PATH = 'bert_training_v4/downstream_tasks/glycan_classification_with_wurcs.csv'
    INDEX_PATH = 'data/multimodal_index.csv'
    SEQUENCES_PATH = 'data/sequences.pkl'
    VOCAB_PATH = 'data/vocabulary.json'
    # Output paths
    OUTPUT_SEQUENCES = 'data/sequences_augmented.pkl'
    OUTPUT_INDEX = 'data/multimodal_index_augmented.csv'
    # Passed to tokenizer.tokenize for every new glycan
    MAX_LENGTH = 512

    print("=" * 60)
    print("AUGMENTING PRE-TRAINING DATA WITH BENCHMARK GLYCANS")
    print("=" * 60)

    # Load existing data
    print("\n[1/5] Loading existing data...")
    df_bench = pd.read_csv(BENCHMARK_PATH)
    df_idx = pd.read_csv(INDEX_PATH, low_memory=False)
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(SEQUENCES_PATH, 'rb') as f:
        sequences = pickle.load(f)
    print(f" Benchmark: {len(df_bench)} samples")
    print(f" Index: {len(df_idx)} samples")
    print(f" Sequences.pkl: {len(sequences)} samples")

    # Find missing WURCS
    print("\n[2/5] Finding missing WURCS...")
    pretrain_wurcs = set(df_idx['wurcs'].dropna())
    bench_unique = set(df_bench['wurcs'].dropna())
    covered_before = len(bench_unique & pretrain_wurcs)

    # Unique missing WURCS (first benchmark occurrence wins) with species info.
    candidates = df_bench.dropna(subset=['wurcs']).drop_duplicates(
        subset='wurcs', keep='first'
    )
    candidates = candidates[~candidates['wurcs'].isin(pretrain_wurcs)]
    missing_data = {
        row['wurcs']: {
            'species': _or_unknown(row.get('species')),
            'kingdom': _or_unknown(row.get('kingdom')),
        }
        for row in candidates.to_dict('records')
    }
    print(f" Missing unique WURCS: {len(missing_data)}")

    # Initialize tokenizer
    print("\n[3/5] Tokenizing missing WURCS...")
    tokenizer = WURCSTokenizer(VOCAB_PATH)

    # Process missing WURCS
    new_sequences = {}
    new_index_rows = []
    failures = 0
    for i, wurcs in enumerate(tqdm(missing_data, desc="Tokenizing")):
        key = f"benchmark_{i:05d}"
        try:
            result = tokenizer.tokenize(wurcs, max_length=MAX_LENGTH)
            token_ids = result['token_ids']
            sequence_entry = {
                'tokens': None,  # Not used for pre-training
                'token_ids': token_ids,
                # Counts non-zero ids; presumably 0 is the padding id — confirm.
                'length': sum(1 for t in token_ids if t != 0),
                'has_unk_mod': False,
                'wurcs': wurcs,
                'source': 'glycanml_benchmark',
                'residue_ids': result['residue_ids'],
                'has_ms': False,
                'has_3d': False,
                'iupac_name': None,
            }
            index_row = {
                'wurcs': wurcs,
                'accession': key,
                'iupac': None,
                'has_sequence': True,
                'has_ms': False,
                'has_structure': False,
                'num_tokens': result['length'],
                'num_residues': result['num_residues'],
                'has_residue_error': False,
                'num_ms_tokens': 0,
                'num_atoms': 0,
                'num_structural_tokens': 0,
                'monosaccharide_names': ','.join(result.get('monosaccharide_names') or []),
            }
        except Exception as e:
            # Best-effort: report and skip glycans the tokenizer cannot handle.
            failures += 1
            print(f" Warning: Failed to tokenize {wurcs[:50]}... - {e}")
            continue
        # Commit both records together so sequences.pkl and the index never
        # disagree (previously a failure while building the index row left an
        # orphan sequence entry behind).
        new_sequences[key] = sequence_entry
        new_index_rows.append(index_row)
    print(f" Successfully tokenized: {len(new_sequences)}")
    if failures:
        print(f" Failed to tokenize: {failures}")

    # Merge with existing data
    print("\n[4/5] Merging with existing data...")
    clashing = sequences.keys() & new_sequences.keys()
    if clashing:
        raise ValueError(
            f"{len(clashing)} generated key(s) already exist in {SEQUENCES_PATH} "
            f"(e.g. {min(clashing)!r}); refusing to overwrite existing sequences"
        )
    augmented_sequences = {**sequences, **new_sequences}
    print(f" Original sequences: {len(sequences)}")
    print(f" New sequences: {len(new_sequences)}")
    print(f" Total sequences: {len(augmented_sequences)}")

    df_new = pd.DataFrame(new_index_rows)
    df_augmented = pd.concat([df_idx, df_new], ignore_index=True)
    print(f" Original index: {len(df_idx)}")
    print(f" New index entries: {len(df_new)}")
    print(f" Total index: {len(df_augmented)}")

    # Save augmented data
    print("\n[5/5] Saving augmented data...")
    with open(OUTPUT_SEQUENCES, 'wb') as f:
        pickle.dump(augmented_sequences, f)
    print(f" Saved: {OUTPUT_SEQUENCES}")
    df_augmented.to_csv(OUTPUT_INDEX, index=False)
    print(f" Saved: {OUTPUT_INDEX}")

    # Summary: report the measured coverage instead of a hard-coded claim,
    # since tokenization failures leave some benchmark glycans uncovered.
    added_wurcs = {row['wurcs'] for row in new_index_rows}
    covered_after = len(bench_unique & (pretrain_wurcs | added_wurcs))
    total = len(bench_unique)
    print("\n" + "=" * 60)
    print("AUGMENTATION COMPLETE")
    print("=" * 60)
    print(f"Added {len(new_sequences)} glycans from benchmark")
    print(
        f"New overlap: {_format_coverage(covered_after, total)} "
        f"(was {_format_coverage(covered_before, total)})"
    )
    print()
    print("To use augmented data, rename files:")
    print(" mv data/sequences.pkl data/sequences_original.pkl")
    print(" mv data/sequences_augmented.pkl data/sequences.pkl")
    print(" mv data/multimodal_index.csv data/multimodal_index_original.csv")
    print(" mv data/multimodal_index_augmented.csv data/multimodal_index.csv")
    print()
    print("Then re-run pre-training:")
    print(" python training/train_multimodal.py --config model/multimodal_config.yaml --restart")


if __name__ == "__main__":
    main()