#!/usr/bin/env python3
"""
Augment pre-training data with missing GlycanML benchmark glycans.

Adds every benchmark glycan (unique WURCS string) that is not yet present in
the pre-training data (data/multimodal_index.csv / data/sequences.pkl).
When this script was written that was ~5,971 glycans, raising benchmark
coverage from 39% to 100%; the actual before/after coverage is computed and
printed on every run.

Usage:
    python augment_pretraining_data.py
"""

import os
import sys
import pickle
import pandas as pd
from tqdm import tqdm

# Add tokenizer path
sys.path.insert(0, 'bert_training_v4')
from downstream_tasks.utils.tokenizer import WURCSTokenizer

# Prefix of the synthetic keys / accessions given to benchmark-only glycans.
BENCHMARK_KEY_PREFIX = 'benchmark_'


def _find_missing_wurcs(bench_wurcs, pretrain_wurcs):
    """Return the WURCS in *bench_wurcs* absent from *pretrain_wurcs*, deduplicated.

    Order of first appearance is preserved so generated keys are deterministic.
    """
    return [w for w in dict.fromkeys(bench_wurcs) if w not in pretrain_wurcs]


def _first_free_key_index(existing_keys):
    """Return the first ``benchmark_NNNNN`` suffix not already used (0 if none are).

    Prevents new entries from silently overwriting benchmark entries added by a
    previous run when the script is re-run on already-augmented data.
    """
    prefix_len = len(BENCHMARK_KEY_PREFIX)
    used = [
        int(key[prefix_len:])
        for key in existing_keys
        if isinstance(key, str)
        and key.startswith(BENCHMARK_KEY_PREFIX)
        and key[prefix_len:].isdigit()
    ]
    return max(used, default=-1) + 1


def _tokenize_missing(tokenizer, missing_wurcs, key_offset):
    """Tokenize *missing_wurcs* into sequences.pkl entries and index rows.

    Glycans that fail to tokenize are reported and skipped (best effort).
    Returns ``(new_sequences, new_index_rows)``; both contain exactly the same
    glycans because a record is only committed once both parts are built.
    """
    new_sequences = {}
    new_index_rows = []
    for i, wurcs in enumerate(tqdm(missing_wurcs, desc="Tokenizing"), start=key_offset):
        try:
            result = tokenizer.tokenize(wurcs, max_length=512)
            key = f"{BENCHMARK_KEY_PREFIX}{i:05d}"
            sequence_entry = {
                'tokens': None,  # Not used for pre-training
                'token_ids': result['token_ids'],
                # Counts non-zero ids; presumably 0 is the padding id -- TODO confirm
                # against WURCSTokenizer.
                'length': sum(1 for t in result['token_ids'] if t != 0),
                'has_unk_mod': False,
                'wurcs': wurcs,
                'source': 'glycanml_benchmark',
                'residue_ids': result['residue_ids'],
                'has_ms': False,
                'has_3d': False,
                'iupac_name': None,
            }
            index_row = {
                'wurcs': wurcs,
                'accession': key,
                'iupac': None,
                'has_sequence': True,
                'has_ms': False,
                'has_structure': False,
                'num_tokens': result['length'],
                'num_residues': result['num_residues'],
                'has_residue_error': False,
                'num_ms_tokens': 0,
                'num_atoms': 0,
                'num_structural_tokens': 0,
                'monosaccharide_names': ','.join(result.get('monosaccharide_names', [])),
            }
        except Exception as e:  # deliberate per-glycan best effort
            print(f" Warning: Failed to tokenize {wurcs[:50]}... - {e}")
            continue
        # Commit both records together so sequences.pkl and the index agree.
        new_sequences[key] = sequence_entry
        new_index_rows.append(index_row)
    return new_sequences, new_index_rows


def _coverage(bench_unique, known_wurcs):
    """Fraction of the unique benchmark WURCS found in *known_wurcs* (1.0 if none)."""
    if not bench_unique:
        return 1.0
    return len(bench_unique & known_wurcs) / len(bench_unique)


def main():
    """Build data/sequences_augmented.pkl and data/multimodal_index_augmented.csv.

    Loads the benchmark CSV, the multimodal index and sequences.pkl, tokenizes
    benchmark WURCS missing from the index, appends them to both datasets and
    writes the augmented copies (the originals are left untouched).
    """
    # Paths
    BENCHMARK_PATH = 'bert_training_v4/downstream_tasks/glycan_classification_with_wurcs.csv'
    INDEX_PATH = 'data/multimodal_index.csv'
    SEQUENCES_PATH = 'data/sequences.pkl'
    VOCAB_PATH = 'data/vocabulary.json'

    # Output paths
    OUTPUT_SEQUENCES = 'data/sequences_augmented.pkl'
    OUTPUT_INDEX = 'data/multimodal_index_augmented.csv'

    print("="*60)
    print("AUGMENTING PRE-TRAINING DATA WITH BENCHMARK GLYCANS")
    print("="*60)

    # Load existing data
    print("\n[1/5] Loading existing data...")
    df_bench = pd.read_csv(BENCHMARK_PATH)
    df_idx = pd.read_csv(INDEX_PATH, low_memory=False)
    # NOTE: pickle is only safe on trusted files; this is a locally produced artifact.
    with open(SEQUENCES_PATH, 'rb') as f:
        sequences = pickle.load(f)
    print(f" Benchmark: {len(df_bench)} samples")
    print(f" Index: {len(df_idx)} samples")
    print(f" Sequences.pkl: {len(sequences)} samples")

    # Find missing WURCS
    print("\n[2/5] Finding missing WURCS...")
    pretrain_wurcs = set(df_idx['wurcs'].dropna())
    bench_wurcs = df_bench['wurcs'].dropna().tolist()
    missing_wurcs = _find_missing_wurcs(bench_wurcs, pretrain_wurcs)
    print(f" Missing unique WURCS: {len(missing_wurcs)}")

    # Initialize tokenizer
    print("\n[3/5] Tokenizing missing WURCS...")
    tokenizer = WURCSTokenizer(VOCAB_PATH)

    # Never reuse a key/accession already present (e.g. on a re-run).
    existing_keys = set(sequences)
    if 'accession' in df_idx.columns:
        existing_keys.update(df_idx['accession'].dropna().astype(str))
    key_offset = _first_free_key_index(existing_keys)

    new_sequences, new_index_rows = _tokenize_missing(tokenizer, missing_wurcs, key_offset)
    print(f" Successfully tokenized: {len(new_sequences)}")

    # Merge with existing data
    print("\n[4/5] Merging with existing data...")

    # Merge sequences
    augmented_sequences = {**sequences, **new_sequences}
    print(f" Original sequences: {len(sequences)}")
    print(f" New sequences: {len(new_sequences)}")
    print(f" Total sequences: {len(augmented_sequences)}")

    # Merge index
    df_new = pd.DataFrame(new_index_rows)
    df_augmented = pd.concat([df_idx, df_new], ignore_index=True)
    print(f" Original index: {len(df_idx)}")
    print(f" New index entries: {len(df_new)}")
    print(f" Total index: {len(df_augmented)}")

    # Save augmented data
    print("\n[5/5] Saving augmented data...")
    with open(OUTPUT_SEQUENCES, 'wb') as f:
        pickle.dump(augmented_sequences, f)
    print(f" Saved: {OUTPUT_SEQUENCES}")

    df_augmented.to_csv(OUTPUT_INDEX, index=False)
    print(f" Saved: {OUTPUT_INDEX}")

    # Coverage over unique benchmark structures, measured rather than assumed:
    # glycans that failed to tokenize are still missing.
    bench_unique = set(bench_wurcs)
    added_wurcs = {entry['wurcs'] for entry in new_sequences.values()}
    coverage_before = _coverage(bench_unique, pretrain_wurcs)
    coverage_after = _coverage(bench_unique, pretrain_wurcs | added_wurcs)

    # Summary
    print("\n" + "="*60)
    print("AUGMENTATION COMPLETE")
    print("="*60)
    print(f"Added {len(new_sequences)} glycans from benchmark")
    print(f"New overlap: {coverage_after:.0%} (was {coverage_before:.0%})")
    print()
    print("To use augmented data, rename files:")
    print(" mv data/sequences.pkl data/sequences_original.pkl")
    print(" mv data/sequences_augmented.pkl data/sequences.pkl")
    print(" mv data/multimodal_index.csv data/multimodal_index_original.csv")
    print(" mv data/multimodal_index_augmented.csv data/multimodal_index.csv")
    print()
    print("Then re-run pre-training:")
    print(" python training/train_multimodal.py --config model/multimodal_config.yaml --restart")


if __name__ == "__main__":
    main()