File size: 5,556 Bytes
1d6f391 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | #!/usr/bin/env python3
"""
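Augment pre-training data with missing GlycanML benchmark glycans.

This script tokenizes the ~5,971 benchmark glycans (WURCS strings) that are
absent from the pre-training index (multimodal_index.csv). It then writes
augmented copies of sequences.pkl and multimodal_index.csv. This raises
benchmark coverage from 39% to 100% (figures measured when the script was
written; glycans that fail to tokenize are skipped). The original files are
not modified.

Usage:
    python augment_pretraining_data.py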
"""
import os
import sys
import pickle
import pandas as pd
from tqdm import tqdm
# Add tokenizer path
sys.path.insert(0, 'bert_training_v4')
from downstream_tasks.utils.tokenizer import WURCSTokenizer
def _find_missing_wurcs(df_bench, pretrain_wurcs):
    """Return the unique benchmark WURCS strings that are not in ``pretrain_wurcs``.

    NaN entries are ignored. Order of first appearance in ``df_bench`` is kept,
    so the accession keys assigned later are deterministic for given inputs.
    """
    bench_unique = df_bench['wurcs'].dropna().drop_duplicates()
    return [wurcs for wurcs in bench_unique if wurcs not in pretrain_wurcs]


def _coverage(bench_wurcs, known_wurcs):
    """Return the fraction of the set ``bench_wurcs`` found in ``known_wurcs`` (1.0 if empty)."""
    if not bench_wurcs:
        return 1.0
    return len(bench_wurcs & known_wurcs) / len(bench_wurcs)


def _unused_keys(taken, prefix="benchmark_"):
    """Yield ``f"{prefix}{n:05d}"`` for n = 0, 1, 2, ..., skipping any key in ``taken``.

    Prevents overwriting entries from an earlier augmentation run whose
    ``benchmark_*`` keys are already present in the input data.
    """
    n = 0
    while True:
        key = f"{prefix}{n:05d}"
        if key not in taken:
            yield key
        n += 1


def _make_entries(wurcs, result, key):
    """Build the (sequences.pkl entry, index row) pair for one tokenized WURCS.

    ``result`` is the mapping returned by ``WURCSTokenizer.tokenize``.
    Missing keys raise ``KeyError``, which the caller treats as a failure.
    """
    token_ids = result['token_ids']
    seq_entry = {
        'tokens': None,  # Not used for pre-training
        'token_ids': token_ids,
        # Counts non-zero ids, presumably treating 0 as padding (TODO confirm).
        # NOTE(review): this can differ from the index's 'num_tokens', which
        # uses result['length']. Verify which one training expects.
        'length': sum(1 for t in token_ids if t != 0),
        'has_unk_mod': False,
        'wurcs': wurcs,
        'source': 'glycanml_benchmark',
        'residue_ids': result['residue_ids'],
        'has_ms': False,
        'has_3d': False,
        'iupac_name': None,
    }
    index_row = {
        'wurcs': wurcs,
        'accession': key,
        'iupac': None,
        'has_sequence': True,
        'has_ms': False,
        'has_structure': False,
        'num_tokens': result['length'],
        'num_residues': result['num_residues'],
        'has_residue_error': False,
        'num_ms_tokens': 0,
        'num_atoms': 0,
        'num_structural_tokens': 0,
        'monosaccharide_names': ','.join(result.get('monosaccharide_names', [])),
    }
    return seq_entry, index_row


def main():
    """Add benchmark glycans missing from the pre-training data and save augmented copies.

    Reads the benchmark CSV, the multimodal index and sequences.pkl. Tokenizes
    every benchmark WURCS not already present in the index. Writes
    ``data/sequences_augmented.pkl`` and ``data/multimodal_index_augmented.csv``.
    The input files are never modified. Glycans that fail to tokenize are
    reported and skipped. The printed coverage is computed from the data.
    """
    # Paths
    BENCHMARK_PATH = 'bert_training_v4/downstream_tasks/glycan_classification_with_wurcs.csv'
    INDEX_PATH = 'data/multimodal_index.csv'
    SEQUENCES_PATH = 'data/sequences.pkl'
    VOCAB_PATH = 'data/vocabulary.json'
    # Output paths
    OUTPUT_SEQUENCES = 'data/sequences_augmented.pkl'
    OUTPUT_INDEX = 'data/multimodal_index_augmented.csv'

    print("="*60)
    print("AUGMENTING PRE-TRAINING DATA WITH BENCHMARK GLYCANS")
    print("="*60)

    # Load existing data
    print("\n[1/5] Loading existing data...")
    df_bench = pd.read_csv(BENCHMARK_PATH)
    df_idx = pd.read_csv(INDEX_PATH, low_memory=False)
    # NOTE: pickle.load can execute arbitrary code; only load a trusted sequences.pkl.
    with open(SEQUENCES_PATH, 'rb') as f:
        sequences = pickle.load(f)
    print(f" Benchmark: {len(df_bench)} samples")
    print(f" Index: {len(df_idx)} samples")
    print(f" Sequences.pkl: {len(sequences)} samples")

    # Find missing WURCS
    print("\n[2/5] Finding missing WURCS...")
    pretrain_wurcs = set(df_idx['wurcs'].dropna())
    bench_wurcs = set(df_bench['wurcs'].dropna())
    missing_wurcs = _find_missing_wurcs(df_bench, pretrain_wurcs)
    print(f" Missing unique WURCS: {len(missing_wurcs)}")

    # Tokenize missing WURCS
    print("\n[3/5] Tokenizing missing WURCS...")
    tokenizer = WURCSTokenizer(VOCAB_PATH)
    # Never reuse a key already present in the inputs (e.g. from a previous run).
    taken_keys = set(sequences)
    if 'accession' in df_idx.columns:
        taken_keys.update(df_idx['accession'].dropna().astype(str))
    keys = _unused_keys(taken_keys)

    new_sequences = {}
    new_index_rows = []
    failures = 0
    key = next(keys)
    for wurcs in tqdm(missing_wurcs, desc="Tokenizing"):
        try:
            result = tokenizer.tokenize(wurcs, max_length=512)
            seq_entry, index_row = _make_entries(wurcs, result, key)
        except Exception as e:
            # Best-effort: skip glycans the tokenizer cannot handle, but show why.
            failures += 1
            print(f" Warning: Failed to tokenize {wurcs[:50]}... - {type(e).__name__}: {e}")
            continue
        new_sequences[key] = seq_entry
        new_index_rows.append(index_row)
        key = next(keys)  # consume a key only after it has actually been used
    print(f" Successfully tokenized: {len(new_sequences)}")
    if failures:
        print(f" Failed to tokenize: {failures}")

    # Merge with existing data
    print("\n[4/5] Merging with existing data...")
    augmented_sequences = {**sequences, **new_sequences}
    print(f" Original sequences: {len(sequences)}")
    print(f" New sequences: {len(new_sequences)}")
    print(f" Total sequences: {len(augmented_sequences)}")

    df_new = pd.DataFrame(new_index_rows)
    # Skip the concat when nothing was added (df_new would have no columns).
    if new_index_rows:
        df_augmented = pd.concat([df_idx, df_new], ignore_index=True)
    else:
        df_augmented = df_idx.copy()
    print(f" Original index: {len(df_idx)}")
    print(f" New index entries: {len(df_new)}")
    print(f" Total index: {len(df_augmented)}")

    # Save augmented data
    print("\n[5/5] Saving augmented data...")
    with open(OUTPUT_SEQUENCES, 'wb') as f:
        pickle.dump(augmented_sequences, f)
    print(f" Saved: {OUTPUT_SEQUENCES}")
    df_augmented.to_csv(OUTPUT_INDEX, index=False)
    print(f" Saved: {OUTPUT_INDEX}")

    # Summary: report measured coverage rather than a hard-coded figure.
    coverage_before = _coverage(bench_wurcs, pretrain_wurcs)
    coverage_after = _coverage(
        bench_wurcs, pretrain_wurcs | {row['wurcs'] for row in new_index_rows}
    )
    print("\n" + "="*60)
    print("AUGMENTATION COMPLETE")
    print("="*60)
    print(f"Added {len(new_sequences)} glycans from benchmark")
    print(f"New overlap: {coverage_after:.1%} (was {coverage_before:.1%})")
    print()
    print("To use augmented data, rename files:")
    print(" mv data/sequences.pkl data/sequences_original.pkl")
    print(" mv data/sequences_augmented.pkl data/sequences.pkl")
    print(" mv data/multimodal_index.csv data/multimodal_index_original.csv")
    print(" mv data/multimodal_index_augmented.csv data/multimodal_index.csv")
    print()
    print("Then re-run pre-training:")
    print(" python training/train_multimodal.py --config model/multimodal_config.yaml --restart")
if __name__ == "__main__":
    # Run the augmentation only when executed as a script, not when imported.
    main()
|