# bertose-affinose-training-code / code/bertint/build_combined_dataset.py
# Author: supanthadey1 — commit 1d6f391 (verified):
#   "Add BERTose and AFFINose training code release" (19 kB)
"""
Build combined Carbogrove + glycowork dataset for Bertint V5.
Steps:
1. Load Carbogrove processed data (315K records)
2. Load glycowork binding data (wide) → long (620K records)
3. Merge and filter to proteins with sequences
4. Compute per-experiment RANK PERCENTILE as training target
5. Compute auxiliary targets (z-scores) for ablation
6. Save combined dataset + sequence files
Output columns:
protein_id - canonical protein name
glycan_wurcs - WURCS string (for Bertose)
target_raw - original FractionBound or Z-score
target_rank - per-experiment rank percentile [0,1] ← PRIMARY
target_zscore - z-score per source (global) ← ABLATION
target_zscore_exp - z-score within each experiment ← ABLATION
log_conc - log10(concentration+1) or -1 (unknown)
has_conc - 1 if concentration known, 0 if not
source - 'carbogrove' or 'glycowork'
data_source - platform name (CG) or 'glycowork'
exp_id - experiment identifier
exp_size - records in this experiment
Canonical IDs:
- Protein: amino acid sequence (from UniProt or glycowork 'target')
- Glycan: WURCS string (for Bertose embeddings)
"""
import pandas as pd
import numpy as np
import csv
import json
import os
import logging
from pathlib import Path
# ── Logging ────────────────────────────────────────────────────────
# Root configuration: INFO and above, each message prefixed by a timestamp
# and its level name.
logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Paths ──────────────────────────────────────────────────────────
# Every location hangs off BASE: the parent of the directory that holds
# this script.
BASE = Path(__file__).parent.parent
DATA_DIR = BASE.joinpath('19274777-2')          # raw CSV / JSON inputs
PROCESSED_DIR = BASE.joinpath('processed_v5')   # training_data.csv
OUTPUT_DIR = BASE.joinpath('combined_dataset')  # everything this script writes
# ── Step 1: Load Carbogrove ────────────────────────────────────────
def load_carbogrove() -> "tuple[pd.DataFrame, dict, dict]":
    """Load Carbogrove binding records together with their lookup tables.

    Reads ``PROCESSED_DIR/training_data.csv`` and, per record, attaches:

    * ``WURCS`` - looked up in ``DATA_DIR/AllGlycanInfo_V2.csv`` by
      ``GlycanID``; records without a WURCS string are dropped because no
      Bertose embedding can be generated for them.
    * ``data_source`` - ``DataSource`` from ``DATA_DIR/AllDataInfo_V2.csv``
      by ``DataID`` ('unknown' when the DataID is not listed).
    * the standardized columns ``source``, ``target_raw`` (= FractionBound),
      ``log_conc`` (log10(Concentration + 1), or -1 when the concentration
      is missing), ``has_conc`` (1 when Concentration is present, else 0)
      and ``exp_id`` (``'cg_' + DataID``).

    Protein sequences come from ``OUTPUT_DIR/carbogrove_protein_sequences.csv``
    when it exists. For lectins not found there, they come from rows of
    ``DATA_DIR/recovered_sequences.csv`` with Status == 'RECOVERED' and a
    sequence longer than 20 characters.

    Returns:
        ``(cg, prot_seqs, glycan_info)``:

        * ``cg`` - the filtered, standardized Carbogrove frame.
        * ``prot_seqs`` - maps protein name to ``{'sequence': ..., 'source': ...}``.
        * ``glycan_info`` - maps GlycanID (as read from the CSV) to WURCS.
    """
    logger.info("Loading Carbogrove data...")
    # Main binding data
    cg = pd.read_csv(PROCESSED_DIR / 'training_data.csv')
    logger.info(f" Base Carbogrove: {len(cg):,} records, "
                f"{cg['LectinName'].nunique()} lectins")

    # GlycanID -> WURCS. DictReader yields None for fields missing from a
    # short row, so guard before calling .strip().
    glycan_info = {}
    with open(DATA_DIR / 'AllGlycanInfo_V2.csv', newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            wurcs = (row.get('WURCS') or '').strip()
            if wurcs:
                glycan_info[row['GlycanID']] = wurcs

    # Protein sequences, in priority order:
    # 1. UniProt-fetched sequences (most complete)
    prot_seqs = {}
    cg_seq_path = OUTPUT_DIR / 'carbogrove_protein_sequences.csv'
    if cg_seq_path.exists():
        cg_seqs = pd.read_csv(cg_seq_path)
        for pid, seq, src in zip(cg_seqs['protein_id'], cg_seqs['sequence'], cg_seqs['source']):
            prot_seqs[pid] = {'sequence': seq, 'source': src}
        logger.info(f" Loaded {len(prot_seqs)} Carbogrove protein sequences")

    # 2. Recovered sequences (lectins without UniProt IDs). These only fill
    #    gaps: a name already present from step 1 is never overwritten.
    recovered_path = DATA_DIR / 'recovered_sequences.csv'
    if recovered_path.exists():
        recovered = pd.read_csv(recovered_path)
        recovered = recovered[recovered['Status'] == 'RECOVERED']
        n_added = 0
        for name, seq, src in zip(recovered['LectinName'], recovered['Sequence'], recovered['Source']):
            if name in prot_seqs or pd.isna(seq) or len(str(seq)) <= 20:
                continue
            prot_seqs[name] = {'sequence': seq, 'source': src}
            n_added += 1
        # Report what was actually added, not just how many rows were RECOVERED.
        logger.info(f" Added {n_added} recovered lectin sequences "
                    f"({len(recovered)} rows marked RECOVERED)")

    # Attach WURCS; records without one cannot get a Bertose embedding.
    cg['WURCS'] = cg['GlycanID'].map(glycan_info)
    n_with_wurcs = cg['WURCS'].notna().sum()
    logger.info(f" Records with WURCS: {n_with_wurcs:,} / {len(cg):,}")
    cg = cg[cg['WURCS'].notna()].copy()
    logger.info(f" After WURCS filter: {len(cg):,} records")

    # DataID -> DataSource (platform name)
    data_info = {}
    with open(DATA_DIR / 'AllDataInfo_V2.csv', newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            data_info[int(row['DataID'])] = row.get('DataSource', 'unknown')
    # Unlisted DataIDs get the same 'unknown' default as a missing column.
    cg['data_source'] = cg['DataID'].map(data_info).fillna('unknown')
    logger.info(f" DataSources: {cg['data_source'].nunique()} platforms")

    # Standardize columns (DataID is kept for the per-experiment rank).
    cg['source'] = 'carbogrove'
    cg['target_raw'] = cg['FractionBound']
    # Use the same "unknown" sentinel as glycowork (log_conc=-1, has_conc=0)
    # so has_conc never claims a concentration that log_conc does not carry.
    conc_known = cg['Concentration'].notna()
    cg['log_conc'] = np.log10(cg['Concentration'] + 1).where(conc_known, -1.0)
    cg['has_conc'] = conc_known.astype(int)
    cg['exp_id'] = 'cg_' + cg['DataID'].astype(str)
    return cg, prot_seqs, glycan_info
# ── Step 2: Load & convert glycowork ──────────────────────────────
def load_glycowork(iupac_to_wurcs: dict) -> pd.DataFrame:
    """Load the glycowork binding matrix and convert it from wide to long format.

    ``glycowork.glycan_data.loader.glycan_binding`` is wide. Its metadata
    columns are ``protein`` (name) and ``target`` (amino-acid sequence), and
    every other column is one glycan (IUPAC name). Only glycans present in
    ``iupac_to_wurcs`` are converted. Rows are skipped when the protein name
    is missing or the sequence is shorter than 50 characters. NaN cells are
    dropped.

    Args:
        iupac_to_wurcs: mapping from glycowork IUPAC glycan name to WURCS.

    Returns:
        Long-format frame with columns protein_name, protein_seq, glycan_iupac,
        WURCS, target_raw, target_zscore (both the cell value as float),
        source, log_conc (-1.0 sentinel), has_conc (0), data_source and
        exp_id (``'gw_' + protein_name``). The columns are present even when
        no record survives filtering.
    """
    logger.info("Loading glycowork data...")
    # Local import: glycowork is only needed by this step.
    from glycowork.glycan_data.loader import glycan_binding
    gw = glycan_binding
    meta_cols = ['protein', 'target']
    glycan_cols = [c for c in gw.columns if c not in meta_cols]
    logger.info(f" Wide format: {gw.shape[0]} proteins × {len(glycan_cols)} glycans")

    # Only glycans with a WURCS mapping can be embedded by Bertose.
    mapped_glycans = [g for g in glycan_cols if g in iupac_to_wurcs]
    logger.info(f" Glycans with WURCS: {len(mapped_glycans)} / {len(glycan_cols)}")

    # Wide -> long (only for mapped glycans)
    logger.info(" Converting wide → long format...")
    records = []
    skipped_nan_protein = 0
    skipped_short_seq = 0
    for _, row in gw.iterrows():
        protein_name = row['protein']
        protein_seq = row['target']
        # A missing name may also appear as the literal string 'nan'.
        if pd.isna(protein_name) or str(protein_name).strip() == 'nan':
            skipped_nan_protein += 1
            continue
        protein_name = str(protein_name).strip()
        protein_seq = str(protein_seq) if pd.notna(protein_seq) else ''
        # Proteins without a usable sequence cannot be embedded with ESM-C.
        if len(protein_seq) < 50:
            skipped_short_seq += 1
            continue
        for glycan_iupac in mapped_glycans:
            val = row[glycan_iupac]
            if pd.notna(val):
                records.append({
                    'protein_name': protein_name,
                    'protein_seq': protein_seq,
                    'glycan_iupac': glycan_iupac,
                    'WURCS': iupac_to_wurcs[glycan_iupac],
                    'target_raw': float(val),
                    'target_zscore': float(val),  # Already z-scored
                })
    logger.info(f" Skipped {skipped_nan_protein} rows with NaN protein")
    logger.info(f" Skipped {skipped_short_seq} rows with short/no sequence")

    # Explicit columns keep the schema (and the column lookups below) valid
    # even when every row was filtered out. Without them an empty record list
    # gives a column-less frame and a KeyError.
    gw_long = pd.DataFrame(
        records,
        columns=['protein_name', 'protein_seq', 'glycan_iupac',
                 'WURCS', 'target_raw', 'target_zscore'],
    )
    logger.info(f" Long format: {len(gw_long):,} records")
    logger.info(f" Proteins: {gw_long['protein_name'].nunique()}")
    logger.info(f" Glycans: {gw_long['glycan_iupac'].nunique()}")

    gw_long['source'] = 'glycowork'
    gw_long['log_conc'] = -1.0  # Sentinel: concentration unknown
    gw_long['has_conc'] = 0
    gw_long['data_source'] = 'glycowork'
    # One experiment per protein name: repeated rows of the same protein
    # share an exp_id and are averaged later in merge_datasets.
    gw_long['exp_id'] = 'gw_' + gw_long['protein_name']
    return gw_long
# ── Step 3: Merge datasets ────────────────────────────────────────
def _as_unified(df: pd.DataFrame, protein_col: str) -> pd.DataFrame:
    """Project one source's standardized frame onto the shared merge schema."""
    return pd.DataFrame({
        'protein_id': df[protein_col],
        'glycan_wurcs': df['WURCS'],
        'target_raw': df['target_raw'],
        'log_conc': df['log_conc'],
        'has_conc': df['has_conc'],
        'source': df['source'],
        'data_source': df['data_source'],
        'exp_id': df['exp_id'],
    })


def merge_datasets(
    cg: pd.DataFrame,
    gw: pd.DataFrame,
    prot_seqs: dict,
) -> pd.DataFrame:
    """Merge Carbogrove and glycowork records into one deduplicated frame.

    Args:
        cg: Carbogrove frame from ``load_carbogrove`` (protein in ``LectinName``).
        gw: glycowork long frame from ``load_glycowork`` (protein in
            ``protein_name``).
        prot_seqs: unused here; kept so existing callers remain valid.

    Returns:
        One row per (protein_id, glycan_wurcs, exp_id), sorted by those keys,
        with columns:

        - protein_id, glycan_wurcs, exp_id: group keys
        - target_raw: mean of the duplicate measurements
        - log_conc: log10(concentration + 1), or -1 when unknown
        - has_conc: 1 if the concentration is known, else 0
        - source: 'carbogrove' or 'glycowork'
        - data_source: platform name (Carbogrove) or 'glycowork'

        Rows whose key contains NaN are dropped (pandas ``groupby`` default).
    """
    logger.info("Merging datasets...")
    combined = pd.concat(
        [_as_unified(cg, 'LectinName'), _as_unified(gw, 'protein_name')],
        ignore_index=True,
    )
    logger.info(f" Combined (pre-dedup): {len(combined):,} records")
    logger.info(f" Carbogrove: {(combined['source'] == 'carbogrove').sum():,}")
    logger.info(f" glycowork: {(combined['source'] == 'glycowork').sum():,}")

    # ── Deduplication ──────────────────────────────────────────────
    # Two sources of duplicates:
    # 1. glycowork: some proteins have several rows in the binding matrix
    #    (same protein tested on different arrays). This creates duplicate
    #    (protein, glycan) pairs with different Z-scores.
    # 2. Carbogrove: different GlycanIDs can map to the same WURCS
    #    (structurally identical glycans with different printed names).
    #    This creates duplicate (protein, glycan, experiment) entries.
    # Fix: average target_raw per (protein_id, glycan_wurcs, exp_id), and take
    # the first value of the remaining columns, which are expected to be
    # constant within each group.
    group_cols = ['protein_id', 'glycan_wurcs', 'exp_id']
    before = len(combined)
    combined = combined.groupby(group_cols, as_index=False).agg({
        'target_raw': 'mean',     # Average replicate measurements
        'log_conc': 'first',      # Constant within experiment
        'has_conc': 'first',      # Constant within source
        'source': 'first',        # Constant within experiment
        'data_source': 'first',   # Constant within experiment
    })
    after = len(combined)
    removed = before - after
    logger.info(f" Deduplicated: {before:,} → {after:,} ({removed:,} duplicates removed)")
    return combined
# ── Step 4: Generate protein sequence file ─────────────────────────
def build_protein_sequence_file(
    combined: pd.DataFrame,
    prot_seqs: dict,
    gw_long: pd.DataFrame,
) -> "tuple[pd.DataFrame, pd.DataFrame]":
    """Build the protein_id → sequence table for ESM-C and drop uncovered records.

    Sequences come first from ``prot_seqs``: entries that have a 'sequence'
    key longer than 20 characters. For names not yet covered, the first
    ``protein_seq`` per ``protein_name`` in ``gw_long`` is used when it is
    longer than 50 characters.

    Returns:
        ``(seq_df, combined_filtered)``:

        - ``seq_df`` has columns protein_id, sequence and seq_len, and lists
          only proteins that still occur in ``combined_filtered``. The columns
          are present even when it has no rows.
        - ``combined_filtered`` is a copy of ``combined`` restricted to
          proteins that have a sequence.
    """
    logger.info("Building protein sequence file...")
    seq_map = {}
    # Carbogrove proteins: UniProt-fetched + recovered
    for name, info in prot_seqs.items():
        if 'sequence' in info and len(str(info['sequence'])) > 20:
            seq_map[name] = str(info['sequence'])
    n_cg = len(seq_map)
    logger.info(f" Carbogrove sequences: {n_cg}")

    # glycowork proteins: the sequence is part of the data
    gw_seqs = gw_long.groupby('protein_name')['protein_seq'].first()
    for name, seq in gw_seqs.items():
        if name not in seq_map and len(str(seq)) > 50:
            seq_map[name] = str(seq)
    logger.info(f" glycowork sequences added: {len(seq_map) - n_cg}")

    # Coverage report (guarded against an empty dataset)
    known_ids = set(seq_map)
    all_proteins = set(combined['protein_id'].unique())
    covered = all_proteins & known_ids
    missing = all_proteins - known_ids
    coverage_pct = 100.0 * len(covered) / len(all_proteins) if all_proteins else 0.0
    logger.info(f" Total unique proteins: {len(all_proteins)}")
    logger.info(f" With sequences: {len(covered)} ({coverage_pct:.1f}%)")
    logger.info(f" Missing sequences: {len(missing)}")
    if missing:
        # key=str keeps sorting safe if ids are not all strings
        logger.info(f" Missing: {sorted(missing, key=str)}")

    # Drop records for proteins without sequences
    records_before = len(combined)
    combined_filtered = combined[combined['protein_id'].isin(known_ids)].copy()
    dropped = records_before - len(combined_filtered)
    if dropped > 0:
        logger.info(f" Dropped {dropped:,} records with missing protein sequences")

    # Compute the retained-id set once. Recomputing unique() for every
    # protein inside the comprehension costs O(proteins x records).
    retained = set(combined_filtered['protein_id'].unique())
    seq_df = pd.DataFrame(
        [{'protein_id': pid, 'sequence': seq, 'seq_len': len(seq)}
         for pid, seq in seq_map.items() if pid in retained],
        columns=['protein_id', 'sequence', 'seq_len'],
    )
    return seq_df, combined_filtered
# ── Step 5: Rank-based normalization ───────────────────────────────
def compute_rank_targets(combined: pd.DataFrame, min_exp_size: int = 5) -> pd.DataFrame:
    """Add the per-experiment rank target plus auxiliary z-score targets.

    Experiments with fewer than ``min_exp_size`` records are dropped FIRST, so
    every statistic below is computed on exactly the rows that are returned.
    Rows with a NaN ``exp_id`` are dropped as well.

    Columns added:
        target_rank       - rank percentile in (0, 1] of target_raw within
                            exp_id; ties share their average rank (PRIMARY)
        target_zscore     - target_raw z-scored per source over the retained
                            rows (ABLATION)
        target_zscore_exp - target_raw z-scored within exp_id; 0.0 when the
                            experiment's std is 0 or undefined (ABLATION)
        exp_size          - number of non-null target_raw values in exp_id

    Args:
        combined: merged records with at least exp_id, source and target_raw.
        min_exp_size: minimum records an experiment needs to be kept
            (default 5).

    Returns:
        A new DataFrame. ``combined`` itself is not modified.
    """
    logger.info("Computing targets...")

    # Drop tiny experiments before computing anything. They are too noisy to
    # rank, and removing them only after the global per-source z-score would
    # leave the retained rows' target_zscore off-centre. Ranks and the
    # per-experiment statistics are unaffected by moving this step.
    exp_sizes = combined.groupby('exp_id').size()
    logger.info(f" Experiment sizes: min={exp_sizes.min()}, "
                f"median={exp_sizes.median():.0f}, max={exp_sizes.max()}")
    valid_exps = exp_sizes.index[exp_sizes >= min_exp_size]
    before = len(combined)
    combined = combined[combined['exp_id'].isin(valid_exps)].copy()
    if before > len(combined):
        logger.info(f" Dropped {before - len(combined):,} records from "
                    f"experiments with <{min_exp_size} glycans")

    by_exp = combined.groupby('exp_id')['target_raw']

    # 1. Rank within each experiment, normalized to (0, 1]
    combined['target_rank'] = by_exp.rank(method='average', pct=True)

    # 2. Global z-score per source
    combined['target_zscore'] = np.nan
    for src in ['carbogrove', 'glycowork']:
        mask = combined['source'] == src
        raw = combined.loc[mask, 'target_raw']
        combined.loc[mask, 'target_zscore'] = (raw - raw.mean()) / (raw.std() + 1e-10)
    logger.info(" Global z-scores computed per source")

    # 3. Z-score within each experiment. A zero or NaN std (a constant or
    #    single-valued experiment) maps the whole experiment to 0.0.
    exp_mean = by_exp.transform('mean')
    exp_std = by_exp.transform('std')
    combined['target_zscore_exp'] = np.where(
        exp_std > 0,
        (combined['target_raw'] - exp_mean) / (exp_std + 1e-10),
        0.0,
    )
    logger.info(" Per-experiment z-scores computed")

    # 4. Experiment size (for potential weighting)
    combined['exp_size'] = by_exp.transform('count').astype(int)
    logger.info(f" Ranked across {combined['exp_id'].nunique():,} experiments")

    # Per-source stats
    for src in ['carbogrove', 'glycowork']:
        sub = combined[combined['source'] == src]
        r = sub['target_rank']
        z = sub['target_zscore']
        logger.info(f" [{src}] rank: mean={r.mean():.3f} std={r.std():.3f} | "
                    f"zscore: mean={z.mean():.3f} std={z.std():.3f}")
    return combined
# ── Step 7: Summary statistics ─────────────────────────────────────
def print_summary(combined: pd.DataFrame) -> None:
    """Log a per-source and overall summary of the final dataset.

    Sections are emitted for Carbogrove, glycowork, and then everything
    together. Each reports record, protein, WURCS-glycan and experiment
    counts plus the mean/std of target_rank and target_zscore. A column
    inventory and the overall unique glycan/protein counts follow.
    """
    banner = "=" * 70
    for heading in (banner, "COMBINED DATASET SUMMARY", banner):
        logger.info(heading)

    sections = [
        (name.upper(), combined[combined['source'] == name])
        for name in ('carbogrove', 'glycowork')
    ]
    sections.append(('COMBINED', combined))

    for label, part in sections:
        rank = part['target_rank']
        zscore = part['target_zscore']
        logger.info(f"\n [{label}]")
        logger.info(f" Records: {len(part):,}")
        logger.info(f" Proteins: {part['protein_id'].nunique()}")
        logger.info(f" Glycans (WURCS): {part['glycan_wurcs'].nunique()}")
        logger.info(f" Experiments: {part['exp_id'].nunique()}")
        logger.info(f" target_rank: mean={rank.mean():.3f}, "
                    f"std={rank.std():.3f}")
        logger.info(f" target_zscore: mean={zscore.mean():.3f}, "
                    f"std={zscore.std():.3f}")

    # Column inventory over the whole dataset
    logger.info(f"\n Columns: {list(combined.columns)}")
    logger.info(f" Unique glycan WURCS: {combined['glycan_wurcs'].nunique()}")
    logger.info(f" Unique protein IDs: {combined['protein_id'].nunique()}")
# ── Main ───────────────────────────────────────────────────────────
def main() -> None:
    """Run the full pipeline and write the combined dataset files.

    Loads both sources, merges and deduplicates them, keeps only proteins
    with sequences, and adds the rank/z-score targets. It then writes
    combined_binding_data.csv, protein_sequences.csv and
    unique_glycan_wurcs.csv into OUTPUT_DIR and logs a summary.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Step 1: Carbogrove (the GlycanID -> WURCS table is not needed here)
    cg_df, lectin_seqs, _glycan_wurcs = load_carbogrove()

    # IUPAC -> WURCS mapping for the glycowork glycan columns
    mapping_file = DATA_DIR / 'glycowork_iupac_to_wurcs.json'
    iupac_to_wurcs = json.loads(mapping_file.read_text())
    logger.info(f"Loaded {len(iupac_to_wurcs)} IUPAC→WURCS mappings")

    # Steps 2-5: glycowork, merge, sequence coverage filter, targets
    gw_df = load_glycowork(iupac_to_wurcs)
    merged = merge_datasets(cg_df, gw_df, lectin_seqs)
    seq_df, combined_final = build_protein_sequence_file(merged, lectin_seqs, gw_df)
    combined_final = compute_rank_targets(combined_final)

    # Step 6: persist outputs
    logger.info("\nSaving outputs...")
    combined_final.to_csv(OUTPUT_DIR / 'combined_binding_data.csv', index=False)
    logger.info(f" Saved combined_binding_data.csv: {len(combined_final):,} records")

    seq_df.to_csv(OUTPUT_DIR / 'protein_sequences.csv', index=False)
    logger.info(f" Saved protein_sequences.csv: {len(seq_df)} proteins")

    # Distinct glycans, for Bertose embedding generation
    wurcs_df = pd.DataFrame({'wurcs': combined_final['glycan_wurcs'].unique()})
    wurcs_df.to_csv(OUTPUT_DIR / 'unique_glycan_wurcs.csv', index=False)
    logger.info(f" Saved unique_glycan_wurcs.csv: {len(wurcs_df)} glycans")

    # Step 7: summary
    print_summary(combined_final)


if __name__ == '__main__':
    main()