| |
| """ |
| Augment pre-training data with missing GlycanML benchmark glycans. |
| |
| This script adds ~5,971 glycans from the benchmark that are not in the |
| pre-training data (sequences.pkl), improving coverage from 39% to 100%. |
| |
| Usage: |
| python augment_pretraining_data.py |
| """ |
|
|
| import os |
| import sys |
| import pickle |
| import pandas as pd |
| from tqdm import tqdm |
|
|
| |
| sys.path.insert(0, 'bert_training_v4') |
| from downstream_tasks.utils.tokenizer import WURCSTokenizer |
|
|
|
|
| def main(): |
| |
| BENCHMARK_PATH = 'bert_training_v4/downstream_tasks/glycan_classification_with_wurcs.csv' |
| INDEX_PATH = 'data/multimodal_index.csv' |
| SEQUENCES_PATH = 'data/sequences.pkl' |
| VOCAB_PATH = 'data/vocabulary.json' |
| |
| |
| OUTPUT_SEQUENCES = 'data/sequences_augmented.pkl' |
| OUTPUT_INDEX = 'data/multimodal_index_augmented.csv' |
| |
| print("="*60) |
| print("AUGMENTING PRE-TRAINING DATA WITH BENCHMARK GLYCANS") |
| print("="*60) |
| |
| |
| print("\n[1/5] Loading existing data...") |
| df_bench = pd.read_csv(BENCHMARK_PATH) |
| df_idx = pd.read_csv(INDEX_PATH, low_memory=False) |
| |
| with open(SEQUENCES_PATH, 'rb') as f: |
| sequences = pickle.load(f) |
| |
| print(f" Benchmark: {len(df_bench)} samples") |
| print(f" Index: {len(df_idx)} samples") |
| print(f" Sequences.pkl: {len(sequences)} samples") |
| |
| |
| print("\n[2/5] Finding missing WURCS...") |
| pretrain_wurcs = set(df_idx['wurcs'].dropna()) |
| bench_wurcs = df_bench['wurcs'].dropna().tolist() |
| |
| |
| missing_data = {} |
| for _, row in df_bench.iterrows(): |
| wurcs = row.get('wurcs') |
| if pd.isna(wurcs) or wurcs in pretrain_wurcs or wurcs in missing_data: |
| continue |
| missing_data[wurcs] = { |
| 'species': row.get('species', 'Unknown'), |
| 'kingdom': row.get('kingdom', 'Unknown'), |
| } |
| |
| print(f" Missing unique WURCS: {len(missing_data)}") |
| |
| |
| print("\n[3/5] Tokenizing missing WURCS...") |
| tokenizer = WURCSTokenizer(VOCAB_PATH) |
| |
| |
| new_sequences = {} |
| new_index_rows = [] |
| |
| for i, (wurcs, meta) in enumerate(tqdm(missing_data.items(), desc="Tokenizing")): |
| try: |
| |
| result = tokenizer.tokenize(wurcs, max_length=512) |
| |
| |
| key = f"benchmark_{i:05d}" |
| new_sequences[key] = { |
| 'tokens': None, |
| 'token_ids': result['token_ids'], |
| 'length': len([t for t in result['token_ids'] if t != 0]), |
| 'has_unk_mod': False, |
| 'wurcs': wurcs, |
| 'source': 'glycanml_benchmark', |
| 'residue_ids': result['residue_ids'], |
| 'has_ms': False, |
| 'has_3d': False, |
| 'iupac_name': None, |
| } |
| |
| |
| new_index_rows.append({ |
| 'wurcs': wurcs, |
| 'accession': key, |
| 'iupac': None, |
| 'has_sequence': True, |
| 'has_ms': False, |
| 'has_structure': False, |
| 'num_tokens': result['length'], |
| 'num_residues': result['num_residues'], |
| 'has_residue_error': False, |
| 'num_ms_tokens': 0, |
| 'num_atoms': 0, |
| 'num_structural_tokens': 0, |
| 'monosaccharide_names': ','.join(result.get('monosaccharide_names', [])), |
| }) |
| |
| except Exception as e: |
| print(f" Warning: Failed to tokenize {wurcs[:50]}... - {e}") |
| continue |
| |
| print(f" Successfully tokenized: {len(new_sequences)}") |
| |
| |
| print("\n[4/5] Merging with existing data...") |
| |
| |
| augmented_sequences = {**sequences, **new_sequences} |
| print(f" Original sequences: {len(sequences)}") |
| print(f" New sequences: {len(new_sequences)}") |
| print(f" Total sequences: {len(augmented_sequences)}") |
| |
| |
| df_new = pd.DataFrame(new_index_rows) |
| df_augmented = pd.concat([df_idx, df_new], ignore_index=True) |
| print(f" Original index: {len(df_idx)}") |
| print(f" New index entries: {len(df_new)}") |
| print(f" Total index: {len(df_augmented)}") |
| |
| |
| print("\n[5/5] Saving augmented data...") |
| |
| with open(OUTPUT_SEQUENCES, 'wb') as f: |
| pickle.dump(augmented_sequences, f) |
| print(f" Saved: {OUTPUT_SEQUENCES}") |
| |
| df_augmented.to_csv(OUTPUT_INDEX, index=False) |
| print(f" Saved: {OUTPUT_INDEX}") |
| |
| |
| print("\n" + "="*60) |
| print("AUGMENTATION COMPLETE") |
| print("="*60) |
| print(f"Added {len(new_sequences)} glycans from benchmark") |
| print(f"New overlap: 100% (was 39%)") |
| print() |
| print("To use augmented data, rename files:") |
| print(" mv data/sequences.pkl data/sequences_original.pkl") |
| print(" mv data/sequences_augmented.pkl data/sequences.pkl") |
| print(" mv data/multimodal_index.csv data/multimodal_index_original.csv") |
| print(" mv data/multimodal_index_augmented.csv data/multimodal_index.csv") |
| print() |
| print("Then re-run pre-training:") |
| print(" python training/train_multimodal.py --config model/multimodal_config.yaml --restart") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|