File size: 5,556 Bytes
1d6f391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
"""
Augment pre-training data with missing GlycanML benchmark glycans.

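This script tokenizes the ~5,971 benchmark glycans whose WURCS strings are
not in the pre-training index (data/multimodal_index.csv) and writes
augmented copies of sequences.pkl and the index, improving benchmark
coverage from 39% to 100% (when every missing glycan tokenizes).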

Usage:
    python augment_pretraining_data.py
"""

import os
import sys
import pickle
import pandas as pd
from tqdm import tqdm

# Add tokenizer path
sys.path.insert(0, 'bert_training_v4')
from downstream_tasks.utils.tokenizer import WURCSTokenizer


def _find_missing_wurcs(df_bench, known_wurcs):
    """Return the benchmark WURCS that are not yet part of the pre-training index.

    Args:
        df_bench: Benchmark DataFrame; must contain a ``wurcs`` column.
        known_wurcs: Set of WURCS strings already present in the index.

    Returns:
        A ``(missing, n_unique)`` tuple: ``missing`` is the list of unique,
        non-null benchmark WURCS absent from ``known_wurcs`` (first-seen
        order), ``n_unique`` is the number of unique non-null benchmark WURCS.
    """
    unique_bench = df_bench['wurcs'].dropna().drop_duplicates().tolist()
    missing = [w for w in unique_bench if w not in known_wurcs]
    return missing, len(unique_bench)


def _tokenize_missing(missing_wurcs, tokenizer):
    """Tokenize each missing WURCS and build its sequence entry and index row.

    A glycan whose tokenization (or entry construction) fails is reported and
    skipped. Both records are built before either is stored, so a failure can
    never leave a sequence entry without its matching index row.

    Args:
        missing_wurcs: WURCS strings to tokenize, in the order keys are assigned.
        tokenizer: Object exposing ``tokenize(wurcs, max_length=...)`` that
            returns a dict with ``token_ids``, ``residue_ids``, ``length``,
            ``num_residues`` and optionally ``monosaccharide_names``.

    Returns:
        A ``(new_sequences, new_index_rows)`` tuple: a dict keyed by
        ``benchmark_NNNNN`` and a parallel list of index-row dicts.
    """
    new_sequences = {}
    new_index_rows = []

    for i, wurcs in enumerate(tqdm(missing_wurcs, desc="Tokenizing")):
        try:
            result = tokenizer.tokenize(wurcs, max_length=512)

            # The key uses the position in the missing list, so skipped
            # glycans leave gaps in the numbering rather than shifting keys.
            key = f"benchmark_{i:05d}"
            token_ids = result['token_ids']

            seq_entry = {
                'tokens': None,  # Not used for pre-training
                'token_ids': token_ids,
                # Counts non-zero ids; assumes id 0 is padding -- TODO confirm
                # against the tokenizer vocabulary.
                'length': sum(1 for t in token_ids if t != 0),
                'has_unk_mod': False,
                'wurcs': wurcs,
                'source': 'glycanml_benchmark',
                'residue_ids': result['residue_ids'],
                'has_ms': False,
                'has_3d': False,
                'iupac_name': None,
            }

            index_row = {
                'wurcs': wurcs,
                'accession': key,
                'iupac': None,
                'has_sequence': True,
                'has_ms': False,
                'has_structure': False,
                'num_tokens': result['length'],
                'num_residues': result['num_residues'],
                'has_residue_error': False,
                'num_ms_tokens': 0,
                'num_atoms': 0,
                'num_structural_tokens': 0,
                'monosaccharide_names': ','.join(result.get('monosaccharide_names', [])),
            }
        except Exception as e:
            # Best-effort: one bad glycan must not abort the whole run.
            # str() keeps the warning itself from failing on non-string values.
            print(f"  Warning: Failed to tokenize {str(wurcs)[:50]}... - {e}")
            continue

        # Commit both records together only after both were built successfully.
        new_sequences[key] = seq_entry
        new_index_rows.append(index_row)

    return new_sequences, new_index_rows


def main():
    """Augment the pre-training data with benchmark glycans missing from the index.

    Reads the benchmark CSV, the multimodal index CSV and ``sequences.pkl``,
    tokenizes every unique benchmark WURCS that is absent from the index, and
    writes the merged results to ``data/sequences_augmented.pkl`` and
    ``data/multimodal_index_augmented.csv``. The input files are not modified.
    Progress and a summary (including the measured benchmark coverage before
    and after augmentation) are printed to stdout.
    """
    # Paths
    BENCHMARK_PATH = 'bert_training_v4/downstream_tasks/glycan_classification_with_wurcs.csv'
    INDEX_PATH = 'data/multimodal_index.csv'
    SEQUENCES_PATH = 'data/sequences.pkl'
    VOCAB_PATH = 'data/vocabulary.json'

    # Output paths
    OUTPUT_SEQUENCES = 'data/sequences_augmented.pkl'
    OUTPUT_INDEX = 'data/multimodal_index_augmented.csv'

    print("="*60)
    print("AUGMENTING PRE-TRAINING DATA WITH BENCHMARK GLYCANS")
    print("="*60)

    # Load existing data
    print("\n[1/5] Loading existing data...")
    df_bench = pd.read_csv(BENCHMARK_PATH)
    df_idx = pd.read_csv(INDEX_PATH, low_memory=False)

    # NOTE: unpickling can execute arbitrary code; only run this on a
    # sequences.pkl that comes from a trusted source.
    with open(SEQUENCES_PATH, 'rb') as f:
        sequences = pickle.load(f)

    print(f"  Benchmark: {len(df_bench)} samples")
    print(f"  Index: {len(df_idx)} samples")
    print(f"  Sequences.pkl: {len(sequences)} samples")

    # Find missing WURCS
    print("\n[2/5] Finding missing WURCS...")
    pretrain_wurcs = set(df_idx['wurcs'].dropna())
    missing_wurcs, n_bench_unique = _find_missing_wurcs(df_bench, pretrain_wurcs)
    n_already_covered = n_bench_unique - len(missing_wurcs)

    print(f"  Missing unique WURCS: {len(missing_wurcs)}")

    # Initialize tokenizer
    print("\n[3/5] Tokenizing missing WURCS...")
    tokenizer = WURCSTokenizer(VOCAB_PATH)

    # Process missing WURCS
    new_sequences, new_index_rows = _tokenize_missing(missing_wurcs, tokenizer)

    print(f"  Successfully tokenized: {len(new_sequences)}")
    n_failed = len(missing_wurcs) - len(new_sequences)
    if n_failed:
        print(f"  Failed to tokenize: {n_failed}")

    # Merge with existing data
    print("\n[4/5] Merging with existing data...")

    # Merge sequences
    augmented_sequences = {**sequences, **new_sequences}
    print(f"  Original sequences: {len(sequences)}")
    print(f"  New sequences: {len(new_sequences)}")
    print(f"  Total sequences: {len(augmented_sequences)}")

    # Merge index
    df_new = pd.DataFrame(new_index_rows)
    if new_index_rows:
        df_augmented = pd.concat([df_idx, df_new], ignore_index=True)
    else:
        # Nothing to add: skip concatenating an empty frame.
        df_augmented = df_idx.copy()
    print(f"  Original index: {len(df_idx)}")
    print(f"  New index entries: {len(df_new)}")
    print(f"  Total index: {len(df_augmented)}")

    # Save augmented data
    print("\n[5/5] Saving augmented data...")

    with open(OUTPUT_SEQUENCES, 'wb') as f:
        pickle.dump(augmented_sequences, f)
    print(f"  Saved: {OUTPUT_SEQUENCES}")

    df_augmented.to_csv(OUTPUT_INDEX, index=False)
    print(f"  Saved: {OUTPUT_INDEX}")

    # Summary
    print("\n" + "="*60)
    print("AUGMENTATION COMPLETE")
    print("="*60)
    print(f"Added {len(new_sequences)} glycans from benchmark")
    # Report the measured coverage of unique benchmark WURCS instead of a
    # fixed figure, so skipped (failed) glycans are reflected in the summary.
    if n_bench_unique:
        overlap_before = n_already_covered / n_bench_unique
        overlap_after = (n_already_covered + len(new_sequences)) / n_bench_unique
        print(f"New overlap: {overlap_after:.1%} (was {overlap_before:.1%})")
    else:
        print("New overlap: n/a (benchmark contains no WURCS)")
    print()
    print("To use augmented data, rename files:")
    print("  mv data/sequences.pkl data/sequences_original.pkl")
    print("  mv data/sequences_augmented.pkl data/sequences.pkl")
    print("  mv data/multimodal_index.csv data/multimodal_index_original.csv")
    print("  mv data/multimodal_index_augmented.csv data/multimodal_index.csv")
    print()
    print("Then re-run pre-training:")
    print("  python training/train_multimodal.py --config model/multimodal_config.yaml --restart")


# Run the augmentation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()