Spaces:
Sleeping
Sleeping
| """ | |
| STEP 2: Prepare training data from BindingDB | |
| Run: python 2_prepare_data.py | |
| Takes ~10-15 minutes | |
| Reads the BindingDB TSV (uncompressed .tsv or compressed .tsv.gz) and produces | |
| train.jsonl + val.jsonl with protein sequences, pKi values, and binary labels. | |
| """ | |
| import csv | |
| import gzip | |
| import json | |
| import math | |
| import os | |
| import random | |
| import sys | |
| from collections import Counter | |
| from pathlib import Path | |
DATA_DIR = Path.home() / "alchemy_training" / "data"

# ── Locate the BindingDB file ──
# Fixed fallback names, tried after any BindingDB_All*.tsv found below.
candidates = [
    DATA_DIR / "bindingdb_raw.tsv",  # uncompressed (from .zip extraction)
    DATA_DIR / "bindingdb_raw.tsv.gz",  # legacy .gz format
]

# Also check for any BindingDB TSV that might have been extracted.
# glob() yields matches in arbitrary order; sorting first makes the choice
# deterministic. Each match is inserted at the front, so the lexicographically
# greatest file name ends up as the first candidate.
for p in sorted(DATA_DIR.glob("BindingDB_All*.tsv")):
    candidates.insert(0, p)

# Pick the first candidate that exists and is larger than 1,000,000 bytes
# (presumably to reject empty or truncated downloads).
INPUT_FILE = None
for c in candidates:
    if c.is_file() and c.stat().st_size > 1_000_000:
        INPUT_FILE = c
        break

if INPUT_FILE is None:
    print("❌ ERROR: BindingDB data file not found in ~/alchemy_training/data/")
    print(" Expected one of:")
    for c in candidates:
        print(f" {c}")
    print("\n Run 'bash 1_setup.sh' first, or download manually from:")
    print(" https://www.bindingdb.org/rwd/bind/chemsearch/marvin/SDFdownload.jsp?all_download=yes")
    sys.exit(1)

print(f"=== Loading BindingDB from {INPUT_FILE.name} ({INPUT_FILE.stat().st_size / 1e9:.2f} GB) ===")
# ── Open file (handles both .gz and plain .tsv) ──
def open_tsv(path):
    """Open a TSV file for text reading, transparently handling gzip.

    Args:
        path: ``str`` or path-like location of the file. A name ending in
            ``.gz`` is opened through :mod:`gzip`; anything else is opened
            as a plain text file.

    Returns:
        A text-mode file object decoded as UTF-8. Bytes that cannot be
        decoded are dropped (``errors="ignore"``) instead of raising.

    Raises:
        OSError: If the file cannot be opened.
    """
    # newline="" turns off newline translation; the csv module asks for this
    # on every file object it reads so that it can handle line endings itself.
    text_options = {"encoding": "utf-8", "errors": "ignore", "newline": ""}
    if str(path).endswith(".gz"):
        return gzip.open(path, "rt", **text_options)
    return open(path, "r", **text_options)
# ── Parse and filter ──
# Streams the TSV once and keeps rows that have a sequence, a SMILES string and
# a parseable affinity inside the accepted pKi window. Every rejected row is
# tallied in `skipped` under the reason it was dropped.
rows = []
skipped = Counter()

# Deletes the qualifier characters seen in values such as ">10000" or "~5".
_QUALIFIER_TABLE = str.maketrans("", "", "><=~")

with open_tsv(INPUT_FILE) as f:
    reader = csv.DictReader(f, delimiter="\t")

    # Verify expected columns exist
    if reader.fieldnames is None:
        print("❌ ERROR: Could not read TSV headers. File may be corrupted.")
        sys.exit(1)

    has_ki = "Ki (nM)" in reader.fieldnames
    # First header that starts with the sequence-column name.
    # NOTE(review): csv.DictReader keeps only the last value when a header name
    # repeats; if the export repeats this column per target chain, the value
    # read below would come from the last chain - confirm against the real file.
    seq_col = next((c for c in reader.fieldnames if c.startswith("BindingDB Target Chain Sequence")), None)
    has_seq = seq_col is not None
    has_smi = "Ligand SMILES" in reader.fieldnames

    if not has_seq:
        print("❌ ERROR: No 'BindingDB Target Chain Sequence' column found")
        print(f" Available columns ({len(reader.fieldnames)}):")
        for col in reader.fieldnames[:20]:
            print(f" - {col}")
        sys.exit(1)

    # Fail fast: without this column every row would be skipped as "no_smiles".
    if not has_smi:
        print("❌ ERROR: No 'Ligand SMILES' column found")
        sys.exit(1)

    if not has_ki:
        has_ic50 = "IC50 (nM)" in reader.fieldnames
        if has_ic50:
            print("⚠ 'Ki (nM)' column not found; using 'IC50 (nM)' as fallback")
        else:
            print("❌ ERROR: Neither 'Ki (nM)' nor 'IC50 (nM)' column found")
            sys.exit(1)

    ki_column = "Ki (nM)" if has_ki else "IC50 (nM)"
    print(f" Using affinity column: '{ki_column}'")
    print(f" Sequence column: '{seq_col}'")
    print(" Ligand column: 'Ligand SMILES'")
    print()

    for i, row in enumerate(reader):
        if i % 200_000 == 0 and i > 0:
            print(f" Scanned {i:,} rows, kept {len(rows):,} ...")

        # ── Extract fields ──
        # DictReader pads short rows with None, so row.get(col, "") can still
        # return None; `or ""` keeps a truncated row from crashing .strip().
        sequence = (row.get(seq_col) or "").strip()
        smiles = (row.get("Ligand SMILES") or "").strip()
        ki_raw = (row.get(ki_column) or "").strip()

        # ── Quality filters ──
        if not sequence:
            skipped["no_sequence"] += 1
            continue
        if not smiles:
            skipped["no_smiles"] += 1
            continue
        if not ki_raw:
            skipped["no_affinity"] += 1
            continue

        # Sequence length: ESM2 max is 1022 tokens; skip very short proteins
        if len(sequence) < 50:
            skipped["sequence_too_short"] += 1
            continue
        if len(sequence) > 1022:
            skipped["sequence_too_long"] += 1
            continue

        # Skip very large ligands (unlikely drug-like)
        if len(smiles) > 200:
            skipped["smiles_too_long"] += 1
            continue

        # ── Parse Ki/IC50 → pKi ──
        try:
            # Remove inequality prefixes like ">", "<", ">=", "~"; the bound is
            # then treated as if it were an exact measurement.
            cleaned = ki_raw.translate(_QUALIFIER_TABLE).strip()
            ki_nm = float(cleaned)
            if ki_nm <= 0:
                skipped["negative_ki"] += 1
                continue
            ki_m = ki_nm * 1e-9  # nM → M
            # A tiny positive value can underflow to 0.0 here; log10 then
            # raises ValueError, which is counted as a parse error below.
            pki = -math.log10(ki_m)  # pKi = -log10(Ki_M)
            # Keep pharmacologically relevant range: pKi 4–12 (100 µM to 1 pM)
            if not (4.0 <= pki <= 12.0):
                skipped["pki_out_of_range"] += 1
                continue
        except (ValueError, OverflowError):
            skipped["parse_error"] += 1
            continue

        # ── Keep row ──
        rows.append({
            "sequence": sequence,
            "smiles": smiles,
            "pki": round(pki, 3),
            # Label: 1 = active binder (pKi >= 6, i.e. Ki <= 1 µM), 0 = weak/inactive
            "label": 1 if pki >= 6.0 else 0,
        })

        # Soft cap at 300K rows to keep training time reasonable (~6-8 hrs).
        # The cap keeps the first 300K accepted rows in file order.
        if len(rows) >= 300_000:
            print(f" Reached 300K row cap at row {i:,}")
            break
# ── Report statistics ──
# Summarises how many rows survived filtering, the active/weak balance, and
# why the remaining rows were dropped. Exits with status 1 if nothing survived.
total_rows = len(rows)
active = sum(r["label"] for r in rows)
inactive = total_rows - active
percent_base = max(1, total_rows)  # avoids division by zero when no rows were kept
rule = "=" * 60

print(f"\n{rule}")
print(f" Total clean rows: {total_rows:,}")
print(f" Active binders (pKi≥6, Ki≤1µM): {active:,} ({100 * active / percent_base:.1f}%)")
print(f" Weak binders: {inactive:,} ({100 * inactive / percent_base:.1f}%)")
print(rule)

if skipped:
    print("\n Skip reasons:")
    # most_common() lists reasons from most to least frequent.
    for reason, count in skipped.most_common():
        print(f" {reason}: {count:,}")

if total_rows < 10_000:
    print(f"\n⚠ WARNING: Only {total_rows:,} rows kept. This may not be enough for good training.")
    print(" Consider relaxing filters or checking the BindingDB file format.")

if total_rows == 0:
    print("\n❌ ERROR: No valid rows found. Cannot proceed.")
    sys.exit(1)
# ── Shuffle and split 90/10 ──
# Seeding the module-level RNG with a fixed value makes the shuffle, and
# therefore the train/validation partition, repeat from run to run.
random.seed(42)
random.shuffle(rows)

# The first 90% (rounded down) goes to training; the remainder is validation.
split = int(len(rows) * 0.9)
train_rows, val_rows = rows[:split], rows[split:]
# ── Write JSONL ──
# One JSON object per line; both files are written next to the input data.
train_file = DATA_DIR / "train.jsonl"
val_file = DATA_DIR / "val.jsonl"

for out_path, subset in ((train_file, train_rows), (val_file, val_rows)):
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in subset)
# ── Verify files ──
# Reports the size of both output files and echoes the first training record
# as a quick check that the JSONL can be read back.
train_size = train_file.stat().st_size / 1e6
val_size = val_file.stat().st_size / 1e6
print(f"\n ✓ {len(train_rows):,} train rows → {train_file} ({train_size:.1f} MB)")
print(f" ✓ {len(val_rows):,} val rows → {val_file} ({val_size:.1f} MB)")

# Quick sanity check: read first row back (same encoding the file was written with)
with open(train_file, encoding="utf-8") as f:
    first_line = f.readline()

# The train split is empty when only one row was kept (int(1 * 0.9) == 0);
# json.loads("") would raise, so skip the sample in that case.
sample = json.loads(first_line) if first_line.strip() else None
if sample is None:
    print("\n⚠ WARNING: train.jsonl is empty; no sample row to show.")
else:
    print("\n Sample row:")
    print(f" sequence: {sample['sequence'][:60]}... ({len(sample['sequence'])} aa)")
    print(f" smiles: {sample['smiles'][:60]}")
    print(f" pki: {sample['pki']}")
    print(f" label: {sample['label']}")

print(f"\n{'='*60}")
print(" Data ready. Run: python 3_train.py")
print(f"{'='*60}")