Commit 6491864
Parent(s): 1a47b7e

Initial deployment: Antibody non-specificity predictor

- ESM-1v (650M) + Logistic Regression
- Trained on Boughter dataset
- Pydantic v2 validation
- Gradio 5.x UI

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- README.md +62 -7
- app.py +152 -0
- experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl +3 -0
- pyproject.toml +215 -0
- requirements.txt +28 -0
- src/antibody_training_esm/__init__.py +0 -0
- src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc +0 -0
- src/antibody_training_esm/__pycache__/settings.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/__init__.py +10 -0
- src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/app.py +197 -0
- src/antibody_training_esm/cli/predict.py +116 -0
- src/antibody_training_esm/cli/preprocess.py +84 -0
- src/antibody_training_esm/cli/test.py +155 -0
- src/antibody_training_esm/cli/testing/__init__.py +1 -0
- src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc +0 -0
- src/antibody_training_esm/cli/testing/config.py +62 -0
- src/antibody_training_esm/cli/testing/data.py +73 -0
- src/antibody_training_esm/cli/testing/evaluation.py +134 -0
- src/antibody_training_esm/cli/testing/tester.py +384 -0
- src/antibody_training_esm/cli/testing/visualization.py +127 -0
- src/antibody_training_esm/cli/train.py +42 -0
- src/antibody_training_esm/conf/__init__.py +9 -0
- src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc +0 -0
- src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc +0 -0
- src/antibody_training_esm/conf/classifier/logreg.yaml +12 -0
- src/antibody_training_esm/conf/classifier/xgboost.yaml +14 -0
- src/antibody_training_esm/conf/config.yaml +36 -0
- src/antibody_training_esm/conf/config_schema.py +142 -0
- src/antibody_training_esm/conf/data/boughter_jain.yaml +23 -0
- src/antibody_training_esm/conf/hardware/default.yaml +5 -0
- src/antibody_training_esm/conf/hydra/default.yaml +10 -0
- src/antibody_training_esm/conf/model/esm1v.yaml +4 -0
- src/antibody_training_esm/conf/model/esm2_650m.yaml +3 -0
- src/antibody_training_esm/conf/predict.yaml +26 -0
- src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml +7 -0
- src/antibody_training_esm/core/__init__.py +19 -0
- src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc +0 -0
- src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc +0 -0
- src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,67 @@
 ---
-title: Antibody Predictor
-emoji:
-colorFrom:
-colorTo:
+title: Antibody Non-Specificity Predictor
+emoji: 🧬
+colorFrom: blue
+colorTo: green
 sdk: gradio
-sdk_version:
-app_file: app.py
+sdk_version: "5.0.0"
+app_file: spaces/app.py
 pinned: false
+license: mit
+tags:
+- antibody
+- protein
+- ESM
+- gradio
+- polyreactivity
+- machine-learning
 ---
 
-
+# 🧬 Antibody Non-Specificity Predictor
+
+Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) or Variable Light (VL) sequences using ESM-1v protein language models.
+
+## Model
+
+- **Architecture:** ESM-1v (650M parameters) + Logistic Regression
+- **Training Data:** Boughter dataset (914 antibodies, ELISA polyreactivity)
+- **Methodology:** Sakhnini et al. (2025) - Prediction of Antibody Non-Specificity using PLMs
+
+## Usage
+
+1. Paste your antibody VH or VL amino acid sequence
+2. Click "🔬 Predict Non-Specificity"
+3. Get prediction (specific vs non-specific) + probability
+
+## Supported Input
+
+- **Valid characters:** Standard amino acids (ACDEFGHIKLMNPQRSTVWY)
+- **Max length:** 2000 amino acids
+- **Auto-cleaning:** Lowercase automatically converted to uppercase
+
+## Examples
+
+The app includes example sequences:
+- Standard VH (128aa)
+- Standard VL (107aa)
+- Short VH (Herceptin-like)
+
+## Citation
+
+If you use this tool in your research, please cite:
+
+```bibtex
+@article{sakhnini2025antibody,
+  title={Prediction of Antibody Non-Specificity using Protein Language Models},
+  author={Sakhnini, et al.},
+  year={2025}
+}
+```
+
+## Repository
+
+Full source code: [antibody_training_pipeline_ESM](https://github.com/The-Obstacle-Is-The-Way/antibody_training_pipeline_ESM)
+
+## License
+
+MIT License - See repository for details
app.py ADDED
@@ -0,0 +1,152 @@

"""
Hugging Face Spaces Gradio App for Antibody Non-Specificity Prediction

Simplified deployment version (no Hydra, no complex dependencies).
Works on HF Spaces free CPU tier.

Local app (src/antibody_training_esm/cli/app.py) remains unchanged.
"""

import logging
import os

import gradio as gr
import torch
from pydantic import ValidationError

from antibody_training_esm.core.prediction import Predictor
from antibody_training_esm.models.prediction import PredictionRequest

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# HF Spaces environment detection
IS_HF_SPACE = os.getenv("SPACE_ID") is not None

# Model path (either local or downloaded from HF Hub)
MODEL_PATH = os.getenv(
    "MODEL_PATH", "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"
)

# ESM model name
MODEL_NAME = "facebook/esm1v_t33_650M_UR90S_1"

# Force CPU for HF Spaces free tier
DEVICE = "cpu"

# Load model globally (HF Spaces best practice)
logger.info(f"Loading model from {MODEL_PATH}...")
predictor = Predictor(
    model_name=MODEL_NAME, classifier_path=MODEL_PATH, device=DEVICE, config_path=None
)

# Warm up model
try:
    logger.info("Warming up model...")
    predictor.predict_single("QVQL")
    logger.info("Model ready!")
except Exception as e:
    logger.warning(f"Warmup failed (non-fatal): {e}")


def predict_sequence(sequence: str) -> tuple[str, str]:
    """
    Prediction function for Gradio interface.

    Args:
        sequence: Antibody amino acid sequence

    Returns:
        Tuple of (prediction, probability)
    """
    try:
        # Validate with Pydantic
        request = PredictionRequest(sequence=sequence)

        # Log request
        logger.info(f"Processing sequence: length={len(request.sequence)}")

        # Predict
        result = predictor.predict_single(request)

        # Format probability
        prob_percent = f"{result.probability:.1%}"

        return result.prediction, prob_percent

    except ValidationError as e:
        # User-friendly error message
        error_msg = e.errors()[0]["msg"]
        raise gr.Error(error_msg) from e
    except torch.cuda.OutOfMemoryError as e:
        logger.error("GPU OOM during inference")
        raise gr.Error(
            "Server overloaded (GPU OOM). Please try again in a moment."
        ) from e
    except Exception as e:
        logger.exception("Unexpected prediction failure")
        raise gr.Error(f"Prediction failed: {str(e)}") from e


# Example sequences
examples = [
    [
        "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS"
    ],  # Standard VH
    [
        "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK"
    ],  # Standard VL
    [
        "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS"
    ],  # Short VH
]

# Create Gradio interface
iface = gr.Interface(
    fn=predict_sequence,
    inputs=gr.TextArea(
        lines=7,
        max_lines=20,
        max_length=2000,
        label="Antibody Sequence (VH or VL)",
        placeholder="Paste amino acid sequence here (e.g., QVQL...)",
        info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).",
        show_copy_button=True,
    ),
    outputs=[
        gr.Textbox(label="Prediction", show_copy_button=True),
        gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True),
    ],
    title="🧬 Antibody Non-Specificity Predictor",
    description=(
        "Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) "
        "or Variable Light (VL) sequences using ESM-1v protein language models.\n\n"
        "**Model:** ESM-1v (650M parameters) + Logistic Regression\n"
        "**Training:** Boughter dataset (914 antibodies, ELISA polyreactivity)\n"
        "**Citation:** Sakhnini et al. (2025) - Prediction of Antibody Non-Specificity using PLMs"
    ),
    article=(
        f"**Model:** {MODEL_NAME}\n"
        f"**Device:** {DEVICE}\n"
        f"**Environment:** {'Hugging Face Spaces' if IS_HF_SPACE else 'Local'}"
    ),
    examples=examples,
    cache_examples=False,  # Don't cache on HF Spaces (saves disk)
    flagging_mode="never",
    analytics_enabled=False,
    submit_btn="🔬 Predict Non-Specificity",
    clear_btn="🗑️ Clear",
)

# Enable queue for concurrency
iface.queue(default_concurrency_limit=2, max_size=10)

# Launch app
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",  # Required for HF Spaces
        server_port=7860,
        share=False,
        show_api=False,  # No public REST API
    )
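Because the predictor is loaded at module import time, the handler can be exercised without launching the UI at all. A local smoke-test sketch (assumes the package from this commit is installed and the LFS checkpoint has been pulled):

```python
# Sketch: import-time loading means predict_sequence works without the UI.
# Importing `app` triggers the global model load and warm-up above.
from app import predict_sequence

label, prob = predict_sequence(
    "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKG"
    "RFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS"  # Short VH from `examples`
)
print(label, prob)  # prediction label and formatted probability
```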
experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:d4f77cadfd0ccf3a12c24ce142a91c82b4481d5153a0af662ac4b05a78ef6670
size 11314
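The checkpoint is committed as a Git LFS pointer (hash and size only; the real object is ~11 kB), so `git lfs pull` is needed before the app can load it. A quick inspection sketch, assuming (not confirmed by this diff) that the pickle deserializes to a scikit-learn estimator:

```python
# Assumption: the .pkl holds a pickled scikit-learn classifier or a thin
# wrapper around one; this diff shows only the LFS pointer, not the object.
import pickle

path = "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"
with open(path, "rb") as fh:  # requires `git lfs pull` first
    clf = pickle.load(fh)
print(type(clf))
```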
pyproject.toml ADDED
@@ -0,0 +1,215 @@

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/antibody_training_esm"]
include = [
    "src/antibody_training_esm/conf/**/*.yaml",
    "src/antibody_training_esm/conf/**/*.py",
]

[tool.hatch.build.targets.sdist]
# Source distribution must include all source files + configs
include = [
    "src/antibody_training_esm/**/*.py",
    "src/antibody_training_esm/conf/**/*.yaml",
    "tests/**/*.py",
    "README.md",
    "pyproject.toml",
    "LICENSE",
]

[project]
name = "antibody-training-esm"
version = "0.7.0"
description = "Professional antibody training pipeline using ESM protein language models"
license = {text = "Apache-2.0"}
requires-python = ">=3.12"
dependencies = [
    "authlib>=1.6.5",
    "biopython>=1.80",
    "brotli>=1.2.0",
    "datasets>=4.2.0",
    "h2>=4.3.0",
    "hydra-core>=1.3.2",
    "jupyterlab>=4.4.9",
    "matplotlib>=3.7.0",
    "more-itertools",
    "numpy>=1.24.0",
    "pandas>=2.0.0",
    "plotly",
    "pyparsing>=3.0.0",
    "PyYAML>=6.0.0",
    "riot_na",
    "scikit-learn>=1.3.0",
    "scipy>=1.10.0",
    "seaborn>=0.12.0",
    "torch>=2.6.0",
    "tqdm>=4.65.0",
    "transformers>=4.30.0",
    "xgboost>=2.0.0",
    "gradio>=4.0.0",
]

[project.optional-dependencies]
validation = [
    "pydantic>=2.10.0",  # Stable v2 release
    "pydantic-settings>=2.6.0",  # For future config management
    "pandera>=0.20.0",  # Phase 3: Data Integrity
]
dev = [
    # Testing
    "pytest>=8.3.0",
    "pytest-cov>=6.0.0",
    "pytest-xdist>=3.6.0",
    "pytest-sugar>=1.0.0",

    # Linting & Formatting
    "ruff>=0.8.0",

    # Type Checking
    "mypy>=1.13.0",
    "pandas-stubs>=2.2.0",

    # Security
    "bandit[toml]>=1.7.0",

    # Pre-commit
    "pre-commit>=4.0.0",

    # Documentation
    "mkdocs>=1.6.0",
    "mkdocs-material>=9.5.0",
    "mkdocstrings[python]>=0.26.0",
    "mkdocs-gen-files>=0.5.0",
    "mkdocs-literate-nav>=0.6.0",
    "mkdocs-section-index>=0.3.0",
    "pymdown-extensions>=10.0.0",
]

[project.scripts]
# Point directly to Hydra-decorated function to enable config group overrides
# (antibody-train model=esm2_650m classifier=xgboost now works correctly)
antibody-train = "antibody_training_esm.core.trainer:main"
antibody-test = "antibody_training_esm.cli.test:main"
antibody-preprocess = "antibody_training_esm.cli.preprocess:main"
antibody-predict = "antibody_training_esm.cli.predict:main"
antibody-app = "antibody_training_esm.cli.app:main"

[tool.ruff]
target-version = "py312"
line-length = 88

[tool.ruff.lint]
select = [
    "E",    # pycodestyle errors
    "W",    # pycodestyle warnings
    "F",    # pyflakes
    "I",    # isort
    "B",    # flake8-bugbear
    "C4",   # flake8-comprehensions
    "UP",   # pyupgrade
    "ARG",  # flake8-unused-arguments
    "SIM",  # flake8-simplify
]
ignore = [
    "E501",  # line too long (handled by formatter)
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/**/*" = ["ARG"]
"experiments/**/*" = ["ALL"]
"reference_repos/**/*" = ["ALL"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"

[tool.mypy]
python_version = "3.12"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
ignore_missing_imports = true
exclude = [
    "experiments/",
    "reference_repos/",
    "site/",  # MkDocs generated documentation
    "tests/unit/cli/test_train.py",  # Legacy CLI tests (deprecated)
]

[tool.pytest.ini_options]
# Pytest Configuration (canonical source - pytest.ini deleted for single source of truth)
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
    # Output formatting
    "-v",
    "--tb=short",
    "--strict-markers",
    "-ra",
    # Coverage reporting
    "--cov=src/antibody_training_esm",
    "--cov-report=html",
    "--cov-report=term-missing",
    # Performance
    "--maxfail=10",
]
markers = [
    "unit: Unit tests (fast, no I/O) - Core business logic",
    "integration: Integration tests (medium speed, some I/O) - Component interactions",
    "e2e: End-to-end tests (slow, full pipeline) - Full workflows",
    "slow: Tests that take >1s to run",
    "gpu: Tests that require GPU (skip in CI with: -m 'not gpu')",
    "legacy: Legacy tests for backward compatibility (deprecated, will be removed)",
]
filterwarnings = [
    # sklearn deprecation warnings
    "ignore:.*__sklearn_tags__.*:DeprecationWarning:sklearn.utils._tags",
    # sklearn convergence warnings (expected with small test datasets)
    "ignore:.*lbfgs failed to converge.*:sklearn.exceptions.ConvergenceWarning",
    "ignore:.*lbfgs failed to converge.*:UserWarning:sklearn.linear_model._logistic",
    # sklearn scoring warnings (expected when testing edge cases)
    "ignore:.*Scoring failed.*:UserWarning:sklearn.model_selection._validation",
    # sklearn undefined metric warnings (expected with edge case test data)
    "ignore:.*Precision is ill-defined.*:sklearn.exceptions.UndefinedMetricWarning",
    "ignore:.*Precision is ill-defined.*:UserWarning:sklearn.metrics._classification",
    # pytest collection warnings (TestConfig is a dataclass, not a test class)
    "ignore:.*cannot collect test class.*TestConfig.*:pytest.PytestCollectionWarning",
    # General deprecation warnings
    "ignore::DeprecationWarning",
    "ignore::PendingDeprecationWarning",
]

[tool.coverage.run]
source = ["src"]
omit = [
    "tests/*",
    "experiments/*",
    "reference_repos/*",
    "**/__pycache__/*",
    ".venv/*",
    "**/conftest.py",
]
branch = true

[tool.coverage.report]
precision = 2
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:",
    "if TYPE_CHECKING:",
]

[dependency-groups]
dev = [
    "openpyxl>=3.1.5",
    "types-pyyaml>=6.0.12.20250915",
]
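The `[project.scripts]` comment above is the key packaging decision in this file: each console script targets the `@hydra.main`-decorated function itself, so Hydra owns `sys.argv` and config-group overrides like `antibody-train model=esm2_650m classifier=xgboost` compose correctly. A minimal sketch of that pattern (illustrative names; the real entry points live in the modules listed above):

```python
# Minimal sketch of the Hydra entry-point pattern; not the repo's trainer.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # `my-script model=esm2_650m classifier=xgboost` lands here with the
    # selected config groups already composed into `cfg`.
    print(cfg)


if __name__ == "__main__":
    main()
```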
requirements.txt ADDED
@@ -0,0 +1,28 @@

# Hugging Face Spaces Requirements
# Minimal dependencies for antibody prediction demo

# Core ML
torch>=2.0.0
transformers>=4.30.0
scikit-learn>=1.3.0
scipy>=1.10.0
joblib>=1.3.0

# Data handling
pandas>=2.0.0
numpy>=1.24.0

# Configuration
omegaconf>=2.3.0

# Validation
pydantic>=2.0.0

# Gradio UI
gradio>=5.0.0

# Progress bars
tqdm>=4.65.0

# Install local package (antibody_training_esm)
.
src/antibody_training_esm/__init__.py ADDED
File without changes

src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (205 Bytes)

src/antibody_training_esm/__pycache__/settings.cpython-312.pyc ADDED
Binary file (9.61 kB)
src/antibody_training_esm/cli/__init__.py ADDED
@@ -0,0 +1,10 @@

"""
CLI Module

Professional command-line interfaces for antibody training pipeline:
- antibody-train: Model training
- antibody-test: Model evaluation
- antibody-preprocess: Dataset preprocessing
"""

__all__ = []
src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (440 Bytes)

src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc ADDED
Binary file (7.8 kB)

src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc ADDED
Binary file (5.09 kB)

src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc ADDED
Binary file (3.6 kB)

src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc ADDED
Binary file (6.49 kB)

src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc ADDED
Binary file (1.29 kB)
src/antibody_training_esm/cli/app.py ADDED
@@ -0,0 +1,197 @@

"""
This module contains the Gradio app for the antibody non-specificity prediction pipeline.
"""

import logging
import platform
from pathlib import Path

import gradio as gr
import hydra
import torch
from omegaconf import DictConfig
from pydantic import ValidationError

from antibody_training_esm.core.prediction import Predictor
from antibody_training_esm.models.prediction import PredictionRequest

# Configure logging
logger = logging.getLogger(__name__)


def launch_gradio_app(cfg: DictConfig) -> None:
    """
    Launches the Gradio web UI for antibody prediction.

    This function sets up a Gradio interface that allows users to input an
    antibody sequence and receive a prediction for its non-specificity.

    Args:
        cfg: The Hydra configuration object.
    """
    # Set log level from config
    logging.basicConfig(
        level=getattr(logging, cfg.gradio.log_level.upper(), logging.INFO)
    )

    # Robust Device & Threading Configuration
    # -------------------------------------------------------------------------
    # 1. Determine the optimal device for inference
    #    - Prefer CUDA if available (Linux/Windows GPU boxes)
    #    - Force CPU on macOS if MPS is detected to avoid Gradio+MPS SegFaults
    #    - Default to configured value otherwise
    device = cfg.model.get("device", "cpu")

    if platform.system() == "Darwin" and device == "mps":
        logger.warning(
            "macOS detected. Forcing CPU for Gradio app stability (MPS workaround)."
        )
        device = "cpu"

    # 2. Configure Threading to prevent OpenMP SegFaults on macOS
    #    - On macOS/CPU, PyTorch's OpenMP runtime can crash inside Gradio threads.
    #    - We restrict it to 1 thread to ensure stability.
    #    - Linux/CUDA systems remain untouched and can use full parallelism.
    if platform.system() == "Darwin" and device == "cpu":
        logger.warning(
            "macOS/CPU detected. Setting torch.set_num_threads(1) to prevent OpenMP crashes."
        )
        torch.set_num_threads(1)

    if cfg.classifier.path is None:
        raise ValueError(
            "Classifier path must be specified via command-line override:\n"
            "  classifier.path=experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"
        )
    classifier_path = Path(cfg.classifier.path)
    if not classifier_path.exists():
        raise FileNotFoundError(
            f"Classifier file not found at {classifier_path}. "
            "Train a model (e.g., `make train`) or download a published checkpoint first."
        )

    # Instantiate the predictor
    config_path = getattr(cfg.classifier, "config_path", None)
    predictor = Predictor(
        model_name=cfg.model.name,
        classifier_path=cfg.classifier.path,
        device=device,
        config_path=config_path,
    )

    # Warm-up: Run a dummy prediction to load the model into memory eagerly
    try:
        logger.info("Warming up model with dummy prediction...")
        predictor.predict_single("QVQL")
        logger.info("Model warmed up and ready.")
    except Exception as e:
        logger.warning(f"Model warm-up failed (non-fatal): {e}")

    def predict_sequence(sequence: str) -> tuple[str, str]:
        """
        Prediction function for the Gradio interface.

        Args:
            sequence: The antibody sequence to predict.

        Returns:
            A tuple containing the prediction string and the formatted probability.
        """
        try:
            # Validate with Pydantic (replaces old validate_input)
            request = PredictionRequest(sequence=sequence)

            # Log request (observability)
            logger.info(f"Processing: length={len(request.sequence)}")

            # Predict (returns PydanticResult)
            result = predictor.predict_single(request)

            # Format probability
            prob_percent = f"{result.probability:.1%}"

            return result.prediction, prob_percent

        except ValidationError as e:
            # Extract first error message for user-friendly display
            error_msg = e.errors()[0]["msg"]
            raise gr.Error(error_msg) from e
        except torch.cuda.OutOfMemoryError as e:
            logger.error("GPU OOM during inference")
            raise gr.Error(
                "Server overloaded (GPU OOM). Please try again in a moment."
            ) from e
        except Exception as e:
            logger.exception("Unexpected prediction failure")
            raise gr.Error(f"Prediction failed: {str(e)}") from e

    # Example sequences (Diverse set)
    examples = [
        [
            "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS"
        ],  # Standard VH
        [
            "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK"
        ],  # Standard VL
        [
            "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS"
        ],  # Short VH (Herceptin-like)
    ]

    # Create the Gradio interface
    iface = gr.Interface(
        fn=predict_sequence,
        inputs=gr.TextArea(
            lines=7,
            max_lines=20,
            max_length=2000,
            label="Antibody Sequence (VH or VL)",
            placeholder="Paste amino acid sequence here (e.g., QVQL...)",
            info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).",
            show_copy_button=True,
        ),
        outputs=[
            gr.Textbox(label="Prediction", show_copy_button=True),
            gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True),
        ],
        title="Antibody Non-Specificity Predictor",
        description=(
            "Enter an antibody Variable Heavy (VH) or Variable Light (VL) sequence "
            "to predict its non-specificity (polyreactivity)."
        ),
        article=f"Model: {cfg.model.name} | Device: {device}",
        examples=examples,
        cache_examples=True,
        flagging_mode="never",
        analytics_enabled=False,
        submit_btn="Predict Non-Specificity",
    )

    # Enable queueing for concurrency management
    """
    Queue Configuration:
    - concurrency_limit: Based on available VRAM (approx 3GB per ESM-1v inference).
    - max_size: Prevents unbounded queue growth under load.
    """
    iface.queue(
        default_concurrency_limit=cfg.gradio.queue.concurrency_limit,
        max_size=cfg.gradio.queue.max_size,
    )

    # Launch the app with hardened settings
    iface.launch(
        server_name=cfg.gradio.server_name,
        server_port=cfg.gradio.server_port,
        share=cfg.gradio.share,
        show_api=False,
    )


@hydra.main(config_path="../conf", config_name="predict", version_base=None)
def main(cfg: DictConfig) -> None:
    """Main function to run the Gradio app."""
    launch_gradio_app(cfg)


if __name__ == "__main__":
    main()
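The device and threading workaround above is inlined in `launch_gradio_app`; condensed into a standalone helper it reads as follows (a sketch, not part of the repo):

```python
# Hypothetical helper: the same macOS stability rules, condensed.
import platform

import torch


def resolve_inference_device(requested: str) -> str:
    """Map a requested device to one that is stable under Gradio."""
    if platform.system() == "Darwin" and requested == "mps":
        return "cpu"  # Gradio + MPS segfault workaround
    return requested


device = resolve_inference_device("mps")
if platform.system() == "Darwin" and device == "cpu":
    torch.set_num_threads(1)  # prevent OpenMP crashes in Gradio worker threads
```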
src/antibody_training_esm/cli/predict.py ADDED
@@ -0,0 +1,116 @@

import sys
from pathlib import Path
from typing import cast

import hydra
import pandas as pd
from omegaconf import DictConfig
from pydantic import ValidationError

from antibody_training_esm.core.config import SEQUENCE_PREVIEW_LENGTH
from antibody_training_esm.core.prediction import Predictor, run_prediction
from antibody_training_esm.models.prediction import AssayType, PredictionRequest


def predict_sequence_cli(
    sequence: str, threshold: float, assay_type: AssayType | None, cfg: DictConfig
) -> None:
    """CLI prediction with Pydantic validation."""
    config_path = getattr(cfg.classifier, "config_path", None)

    # Instantiate predictor (loading model)
    try:
        predictor = Predictor(
            model_name=cfg.model.name,
            classifier_path=cfg.classifier.path,
            config_path=config_path,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    try:
        request = PredictionRequest(
            sequence=sequence,
            threshold=threshold,
            assay_type=assay_type,
        )
        result = predictor.predict_single(request)

        # Print formatted output
        print(
            f"Sequence: {result.sequence[:SEQUENCE_PREVIEW_LENGTH]}..."
            if len(result.sequence) > SEQUENCE_PREVIEW_LENGTH
            else f"Sequence: {result.sequence}"
        )
        print(f"Prediction: {result.prediction}")
        print(f"Probability: {result.probability:.2%}")

    except ValidationError as e:
        print("❌ Validation Error:")
        for error in e.errors():
            # loc is a tuple, e.g. ('sequence',)
            loc = error["loc"][0] if error["loc"] else "root"
            print(f"  - {loc}: {error['msg']}")
        sys.exit(1)


@hydra.main(config_path="../conf", config_name="predict", version_base=None)
def main(cfg: DictConfig) -> None:
    """Main function to run the prediction CLI."""

    # Check for single sequence prediction mode
    sequence = getattr(cfg, "sequence", None)
    if sequence:
        threshold = getattr(cfg, "threshold", 0.5)
        assay_type = cast(AssayType | None, getattr(cfg, "assay_type", None))
        predict_sequence_cli(sequence, threshold, assay_type, cfg)
        return

    # Validate required arguments for batch mode
    if cfg.input_file is None:
        raise ValueError(
            "Input file must be specified via command-line override: `input_file=...`"
        )

    if cfg.classifier.path is None:
        raise ValueError(
            "Classifier path must be specified via command-line override:\n"
            "  classifier.path=experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl\n"
            "  # OR for production models (.npz):\n"
            "  classifier.path=experiments/.../model.npz classifier.config_path=.../model_config.json\n"
            "\nExample usage:\n"
            "  uv run antibody-predict \\\n"
            "    input_file=data/test.csv \\\n"
            "    output_file=predictions.csv \\\n"
            "    classifier.path=path/to/model.pkl"
        )
    classifier_path = Path(cfg.classifier.path)
    if not classifier_path.exists():
        raise FileNotFoundError(
            f"Classifier file not found at {classifier_path}. "
            "Train a model (e.g., `make train`) or download a published checkpoint first."
        )

    try:
        # Load input data
        input_df = pd.read_csv(cfg.input_file)

        # Run prediction
        output_df = run_prediction(input_df, cfg)

        # Save output data
        output_df.to_csv(cfg.output_file, index=False)

        print(f"Predictions saved to {cfg.output_file}")

    except FileNotFoundError:
        print(f"Error: Input file not found at {cfg.input_file}")
        exit(1)
    except Exception as e:
        print(f"An error occurred: {e}")
        exit(1)


if __name__ == "__main__":
    main()
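Batch mode reads the CSV named by `input_file` and writes predictions to `output_file`. The exact input schema lives in `run_prediction` (core/prediction.py, not shown in this view); assuming the same `sequence` column name the test CLI defaults to, a minimal input file could be built like this:

```python
# Assumption: run_prediction expects a 'sequence' column (the default column
# name used elsewhere in this commit); the authoritative schema is not shown.
import pandas as pd

pd.DataFrame(
    {
        "sequence": [
            "QVQLVQSGAEVKKPGASVKVSCKAS",  # truncated VH, illustration only
            "DIQMTQSPSSLSASVGDRVTITCRA",  # truncated VL, illustration only
        ]
    }
).to_csv("test.csv", index=False)
```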
src/antibody_training_esm/cli/preprocess.py ADDED
@@ -0,0 +1,84 @@

"""
Preprocessing CLI

Professional command-line interface for dataset preprocessing.
"""

import argparse
import sys


def main() -> int:
    """
    Main entry point for preprocessing CLI.

    This CLI does NOT run preprocessing - it only provides guidance on which
    preprocessing scripts to use. Preprocessing is handled by specialized
    scripts that are the Single Source of Truth (SSOT).
    """
    parser = argparse.ArgumentParser(
        description="Antibody dataset preprocessing guidance",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
NOTE: This CLI does NOT run preprocessing. It provides guidance on which
preprocessing scripts to use. Each dataset has unique requirements and the
scripts maintain bit-for-bit parity with published methods.
""",
    )

    parser.add_argument(
        "--dataset",
        "-d",
        type=str,
        required=True,
        choices=["jain", "harvey", "shehata", "boughter"],
        help="Dataset to get preprocessing guidance for",
    )

    args = parser.parse_args()

    try:
        print("\n⚠️  The 'antibody-preprocess' CLI is not implemented")
        print(
            "\nDataset preprocessing is handled by specialized scripts, not this CLI."
        )
        print(
            "These scripts are the authoritative source of truth for data transformation."
        )
        print(f"\nFor {args.dataset} dataset, use:")

        script_paths = {
            "jain": "preprocessing/jain/step2_preprocess_p5e_s2.py",
            "harvey": "preprocessing/harvey/step2_extract_fragments.py",
            "shehata": "preprocessing/shehata/step2_extract_fragments.py",
            "boughter": "preprocessing/boughter/stage2_stage3_annotation_qc.py",
        }

        script = script_paths.get(args.dataset)
        if script:
            print(f"  python {script}")

        print("\nWhy use scripts instead of this CLI?")
        print("  • Scripts are Single Source of Truth (SSOT) for preprocessing")
        print(
            "  • Each dataset has unique requirements (DNA translation, PSR thresholds, etc.)"
        )
        print("  • Scripts maintain bit-for-bit parity with published methods")
        print("  • CLI is for loading preprocessed data, not creating it")

        print("\nFor more information:")
        print("  • See src/antibody_training_esm/datasets/README.md")
        print("  • See docs/boughter/boughter_data_sources.md (dataset-specific)")

        return 0

    except KeyboardInterrupt:
        print("\n❌ Error: Interrupted by user", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"\n❌ Error: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    sys.exit(main())
src/antibody_training_esm/cli/test.py ADDED
@@ -0,0 +1,155 @@

#!/usr/bin/env python3
"""
Test CLI for Antibody Classification Pipeline

Professional command-line interface for testing trained antibody classifiers:
1. Load trained models from pickle files
2. Evaluate on test datasets with performance metrics
3. Generate confusion matrices and comprehensive logging

Usage:
    antibody-test --model experiments/checkpoints/antibody_classifier.pkl --data sample_data.csv
    antibody-test --config test_config.yaml
    antibody-test --model m1.pkl m2.pkl --data d1.csv d2.csv
"""

import argparse
import sys

from antibody_training_esm.cli.testing.config import (
    TestConfig,
    create_sample_test_config,
    load_config_file,
)
from antibody_training_esm.cli.testing.tester import ModelTester


def main() -> int:
    """Main entry point for antibody-test CLI"""
    parser = argparse.ArgumentParser(
        description="Testing for antibody classification models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Test single model on single dataset (auto-detects threshold from dataset name)
  antibody-test --model experiments/checkpoints/antibody_classifier.pkl --data sample_data.csv

  # Test on PSR dataset with auto-detected threshold (0.5495 for Harvey/Shehata)
  antibody-test --model model.pkl --data data/test/harvey/fragments/VHH_only_harvey.csv

  # Test multiple models on multiple datasets
  antibody-test --model experiments/checkpoints/model1.pkl experiments/checkpoints/model2.pkl --data dataset1.csv dataset2.csv

  # Use configuration file
  antibody-test --config test_config.yaml

  # Override device, batch size, and threshold
  antibody-test --config test_config.yaml --device cuda --batch-size 64 --threshold 0.6

  # Create sample configuration
  antibody-test --create-config
""",
    )

    parser.add_argument(
        "--model", nargs="+", help="Path(s) to trained model pickle files"
    )
    parser.add_argument("--data", nargs="+", help="Path(s) to test dataset CSV files")
    parser.add_argument("--config", help="Path to test configuration YAML file")
    parser.add_argument(
        "--output-dir",
        default="./experiments/benchmarks",
        help="Output directory for results",
    )
    parser.add_argument(
        "--device",
        choices=["cpu", "cuda", "mps"],
        help="Device to use for inference (overrides config)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size for embedding extraction (overrides config)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        help="Manual decision threshold override (default: auto-detect from dataset name). "
        "Use 0.5 for ELISA datasets (Boughter, Jain) or 0.5495 for PSR datasets (Harvey, Shehata).",
    )
    parser.add_argument(
        "--sequence-column",
        type=str,
        help="Column name for sequences in dataset (default: 'sequence', overrides config)",
    )
    parser.add_argument(
        "--label-column",
        type=str,
        help="Column name for labels in dataset (default: 'label', overrides config)",
    )
    parser.add_argument(
        "--create-config", action="store_true", help="Create sample configuration file"
    )

    args = parser.parse_args()

    # Create sample config if requested
    if args.create_config:
        create_sample_test_config()
        return 0

    # Load configuration
    if args.config:
        config = load_config_file(args.config)
    else:
        if not args.model or not args.data:
            parser.error("Either --config or both --model and --data must be specified")

        config = TestConfig(
            model_paths=args.model, data_paths=args.data, output_dir=args.output_dir
        )

    # Override config with command line arguments
    if args.device:
        config.device = args.device
    if args.batch_size:
        config.batch_size = args.batch_size
    if args.threshold:
        config.threshold = args.threshold
    if args.sequence_column:
        config.sequence_column = args.sequence_column
    if args.label_column:
        config.label_column = args.label_column

    # Run testing
    try:
        tester = ModelTester(config)
        results = tester.run_comprehensive_test()

        print(f"\n{'=' * 60}")
        print("TESTING COMPLETED SUCCESSFULLY!")
        print(f"{'=' * 60}")
        print(f"Results saved to: {config.output_dir}")

        # Print summary
        for dataset_name, dataset_results in results.items():
            print(f"\nDataset: {dataset_name}")
            print("-" * 40)
            for model_name, model_results in dataset_results.items():
                print(f"Model: {model_name}")
                if "test_scores" in model_results:
                    for metric, value in model_results["test_scores"].items():
                        print(f"  {metric}: {value:.4f}")

        return 0

    except KeyboardInterrupt:
        print("Error during testing: Interrupted by user", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"Error during testing: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    sys.exit(main())
src/antibody_training_esm/cli/testing/__init__.py ADDED
@@ -0,0 +1 @@

"""Test CLI package."""
src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (249 Bytes)

src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc ADDED
Binary file (2.62 kB)

src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc ADDED
Binary file (3.47 kB)

src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc ADDED
Binary file (5.38 kB)

src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc ADDED
Binary file (17.2 kB)

src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc ADDED
Binary file (5.21 kB)
src/antibody_training_esm/cli/testing/config.py ADDED
@@ -0,0 +1,62 @@

"""Configuration management for the testing pipeline."""

from dataclasses import dataclass

import yaml

from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE


@dataclass
class TestConfig:
    """Configuration for testing pipeline"""

    model_paths: list[str]
    data_paths: list[str]
    sequence_column: str = "sequence"  # Column name for sequences in dataset
    label_column: str = "label"  # Column name for labels in dataset
    output_dir: str = "./experiments/benchmarks"
    metrics: list[str] | None = None
    save_predictions: bool = True
    batch_size: int = DEFAULT_BATCH_SIZE  # Batch size for embedding extraction
    device: str = "mps"  # Device to use for inference [cuda, cpu, mps] - MUST match training config
    threshold: float | None = (
        None  # Manual threshold override (None = auto-detect from dataset name)
    )

    def __post_init__(self) -> None:
        if self.metrics is None:
            self.metrics = [
                "accuracy",
                "precision",
                "recall",
                "f1",
                "roc_auc",
                "pr_auc",
            ]


def load_config_file(config_path: str) -> TestConfig:
    """Load test configuration from YAML file"""
    with open(config_path) as f:
        config_dict = yaml.safe_load(f)

    return TestConfig(**config_dict)


def create_sample_test_config() -> None:
    """Create a sample test configuration file"""
    sample_config = {
        "model_paths": ["./experiments/checkpoints/antibody_classifier.pkl"],
        "data_paths": ["./sample_data.csv"],
        "sequence_column": "sequence",
        "label_column": "label",
        "output_dir": "./experiments/benchmarks",
        "metrics": ["accuracy", "precision", "recall", "f1", "roc_auc", "pr_auc"],
        "save_predictions": True,
    }

    with open("test_config.yaml", "w") as f:
        yaml.dump(sample_config, f, default_flow_style=False)

    print("Sample test configuration created: test_config.yaml")
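`load_config_file` simply feeds the parsed YAML into the dataclass, so fields absent from the file (e.g. `device`, `batch_size`, `threshold`) keep their dataclass defaults. A round-trip sketch using only the functions defined above:

```python
# Round-trip: write the sample config, then load it back into a TestConfig.
from antibody_training_esm.cli.testing.config import (
    create_sample_test_config,
    load_config_file,
)

create_sample_test_config()  # writes ./test_config.yaml
config = load_config_file("test_config.yaml")
assert config.device == "mps"  # dataclass default; the sample YAML omits it
assert config.threshold is None  # auto-detect from dataset name
```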
src/antibody_training_esm/cli/testing/data.py
ADDED
@@ -0,0 +1,73 @@
+"""Dataset loading and validation utilities."""
+
+import logging
+import os
+
+import pandas as pd
+
+from antibody_training_esm.cli.testing.config import TestConfig
+
+logger = logging.getLogger(__name__)
+
+
+def load_dataset(data_path: str, config: TestConfig) -> tuple[list[str], list[int]]:
+    """
+    Load dataset from CSV file using configured column names.
+
+    Args:
+        data_path: Path to the CSV file.
+        config: Test configuration object containing column names.
+
+    Returns:
+        Tuple of (sequences, labels).
+    """
+    logger.info(f"Loading dataset from {data_path}")
+
+    if not os.path.exists(data_path):
+        raise FileNotFoundError(f"Dataset file not found: {data_path}")
+
+    # Defensive: Handle legacy files with comment headers
+    # New files (post-HF cleanup) are standard CSVs without comments
+    df = pd.read_csv(data_path, comment="#")
+
+    sequence_col = config.sequence_column
+    label_col = config.label_column
+
+    if sequence_col not in df.columns:
+        raise ValueError(
+            f"Sequence column '{sequence_col}' not found in dataset. Available columns: {list(df.columns)}"
+        )
+    if label_col not in df.columns:
+        raise ValueError(
+            f"Label column '{label_col}' not found in dataset. Available columns: {list(df.columns)}"
+        )
+
+    # CRITICAL VALIDATION: Check for NaN labels (P0 bug fix)
+    nan_count = df[label_col].isna().sum()
+    if nan_count > 0:
+        raise ValueError(
+            f"CRITICAL: Dataset contains {nan_count} NaN labels! "
+            f"This will corrupt evaluation metrics. "
+            f"Please use the curated canonical test file (e.g., "
+            f"data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv with no NaNs)."
+        )
+
+    # For Jain test sets, validate expected size (allow legacy 94 + canonical 86)
+    if "jain" in data_path.lower() and "test" in data_path.lower():
+        expected_sizes = {94, 86}
+        if len(df) not in expected_sizes:
+            raise ValueError(
+                f"Jain test set has {len(df)} antibodies but expected one of {sorted(expected_sizes)}. "
+                f"Using the wrong test set will produce invalid metrics. "
+                f"Please use the correct curated file (preferred: "
+                f"data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv)."
+            )
+
+    sequences = df[sequence_col].tolist()
+    labels = df[label_col].tolist()
+
+    logger.info(
+        f"Loaded {len(sequences)} samples from {data_path} (sequence_col='{sequence_col}', label_col='{label_col}')"
+    )
+    logger.info(f"  Label distribution: {pd.Series(labels).value_counts().to_dict()}")
+    return sequences, labels
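
A usage sketch for load_dataset; the TestConfig fields shown mirror conf/testing/jain_p5e_s2.yaml below, and any fields omitted here are assumed to have defaults:

from antibody_training_esm.cli.testing.config import TestConfig
from antibody_training_esm.cli.testing.data import load_dataset

config = TestConfig(
    model_paths=["experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"],
    data_paths=["data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv"],
    sequence_column="vh_sequence",
    label_column="label",
    output_dir="experiments/benchmarks",
)
# Raises on a missing file, missing column, NaN labels, or an off-size Jain test set.
sequences, labels = load_dataset(config.data_paths[0], config)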
src/antibody_training_esm/cli/testing/evaluation.py
ADDED
@@ -0,0 +1,134 @@
+"""Metric calculation and model evaluation utilities."""
+
+import logging
+from typing import Any
+
+import numpy as np
+from sklearn.metrics import (
+    classification_report,
+    confusion_matrix,
+)
+
+from antibody_training_esm.core.classifier import BinaryClassifier
+from antibody_training_esm.models.artifact import EvaluationMetrics
+
+logger = logging.getLogger(__name__)
+
+
+def detect_assay_type(dataset_name: str) -> str | None:
+    """
+    Auto-detect assay type from dataset name for threshold selection
+
+    Args:
+        dataset_name: Name of the dataset (e.g., "VH_only_jain", "VHH_only_harvey")
+
+    Returns:
+        'ELISA' for ELISA-based datasets (Boughter, Jain)
+        'PSR' for PSR-based datasets (Harvey, Shehata)
+        None if unable to detect
+
+    Notes:
+        Novo Nordisk (Sakhnini et al. 2025, Section 2.7):
+        "Antibodies characterised by the PSR assay appear to be on a different
+        non-specificity spectrum than that from the non-specificity ELISA assay."
+
+        PSR datasets require threshold=0.5495 for optimal performance.
+        ELISA datasets use standard threshold=0.5.
+    """
+    dataset_lower = dataset_name.lower()
+
+    # PSR-based datasets (Harvey, Shehata)
+    if any(marker in dataset_lower for marker in ["harvey", "shehata"]):
+        return "PSR"
+
+    # ELISA-based datasets (Boughter, Jain)
+    if any(marker in dataset_lower for marker in ["boughter", "jain"]):
+        return "ELISA"
+
+    # Unable to detect - will use default threshold
+    return None
+
+
+def evaluate_pretrained(
+    model: BinaryClassifier,
+    X: np.ndarray,
+    y: np.ndarray,
+    model_name: str,
+    dataset_name: str,
+    _metrics_list: list[str] | None = None,
+    threshold_override: float | None = None,
+) -> dict[str, Any]:
+    """
+    Evaluate pretrained model directly on test set (no retraining)
+
+    Args:
+        model: The trained BinaryClassifier.
+        X: Embeddings (features).
+        y: True labels.
+        model_name: Name of the model for logging.
+        dataset_name: Name of the dataset for logging.
+        _metrics_list: List of metrics to calculate (default: all).
+        threshold_override: Optional manual threshold.
+
+    Returns:
+        Dictionary of results including scores, predictions, and reports.
+        Contains 'metrics' key with EvaluationMetrics object.
+    """
+    logger.info(f"Evaluating pretrained model {model_name} on {dataset_name}")
+
+    # Determine threshold: manual override > auto-detect > default 0.5
+    if threshold_override is not None:
+        # Manual override via CLI
+        threshold = threshold_override
+        logger.info(f"Using manual threshold override: {threshold}")
+    else:
+        # Auto-detect assay type from dataset name
+        assay_type = detect_assay_type(dataset_name)
+        if assay_type is not None:
+            threshold = model.ASSAY_THRESHOLDS[assay_type]
+            logger.info(
+                f"Auto-detected assay type: {assay_type} → threshold={threshold} "
+                f"(Dataset: {dataset_name})"
+            )
+        else:
+            threshold = 0.5
+            logger.warning(
+                f"Unable to auto-detect assay type for '{dataset_name}'. "
+                f"Using default threshold={threshold}. "
+                f"For optimal results, specify --threshold or use standard dataset names."
+            )
+
+    # Get predictions using the pretrained model with appropriate threshold
+    y_pred = model.predict(
+        X, threshold=threshold, assay_type=None
+    )  # threshold already determined
+    y_proba = model.predict_proba(X)[:, 1]
+
+    # Create Pydantic metrics
+    eval_metrics = EvaluationMetrics.from_sklearn_metrics(
+        y,
+        y_pred,
+        y_proba.reshape(-1, 1) if y_proba.ndim == 1 else y_proba,
+        dataset_name=dataset_name,
+    )
+
+    # Calculate legacy results for compatibility with visualization tools
+    results = {
+        "metrics": eval_metrics,  # Store Pydantic model
+        "test_scores": eval_metrics.model_dump(
+            exclude={"confusion_matrix", "dataset_name", "n_samples"}
+        ),
+        "predictions": {"y_true": y, "y_pred": y_pred, "y_proba": y_proba},
+        "confusion_matrix": confusion_matrix(y, y_pred),
+        "classification_report": classification_report(y, y_pred, output_dict=True),
+    }
+
+    # Log results
+    logger.info(f"Test results for {model_name} on {dataset_name}:")
+    logger.info(f"  Accuracy: {eval_metrics.accuracy:.4f}")
+    if eval_metrics.f1 is not None:
+        logger.info(f"  F1: {eval_metrics.f1:.4f}")
+    if eval_metrics.roc_auc is not None:
+        logger.info(f"  ROC-AUC: {eval_metrics.roc_auc:.4f}")
+
+    return results
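
detect_assay_type keys purely off substrings of the dataset name, so its behaviour is easy to pin down (a sketch):

from antibody_training_esm.cli.testing.evaluation import detect_assay_type

assert detect_assay_type("VHH_only_harvey") == "PSR"           # PSR -> threshold 0.5495
assert detect_assay_type("VH_only_jain_86_p5e_s2") == "ELISA"  # ELISA -> threshold 0.5
assert detect_assay_type("my_custom_set") is None              # default 0.5, with a warning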
src/antibody_training_esm/cli/testing/tester.py
ADDED
@@ -0,0 +1,384 @@
+"""Model orchestration logic."""
+
+import json
+import logging
+import os
+import pickle  # nosec B403
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+
+from antibody_training_esm.cli.testing.config import TestConfig
+from antibody_training_esm.cli.testing.data import load_dataset
+from antibody_training_esm.cli.testing.evaluation import evaluate_pretrained
+from antibody_training_esm.cli.testing.visualization import (
+    plot_confusion_matrix,
+    save_detailed_results,
+)
+from antibody_training_esm.core.classifier import BinaryClassifier
+from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE
+from antibody_training_esm.core.directory_utils import (
+    extract_classifier_shortname,
+    extract_model_shortname,
+    get_hierarchical_test_results_dir,
+)
+from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor
+
+
+class ModelTester:
+    """Model testing orchestrator"""
+
+    def __init__(self, config: TestConfig):
+        self.config = config
+        self.logger = self._setup_logging()
+        self.results: dict[str, Any] = {}
+        self.cached_embedding_files: list[str] = []  # Track cached files for cleanup
+
+        # Create output directory
+        os.makedirs(config.output_dir, exist_ok=True)
+
+    def _setup_logging(self) -> logging.Logger:
+        """Setup logging configuration"""
+        # Create output directory if it doesn't exist
+        os.makedirs(self.config.output_dir, exist_ok=True)
+
+        log_file = os.path.join(
+            self.config.output_dir,
+            f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log",
+        )
+
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
+        )
+
+        return logging.getLogger(__name__)
+
+    def load_model(self, model_path: str) -> BinaryClassifier:
+        """Load trained model from pickle file"""
+        self.logger.info(f"Loading model from {model_path}")
+
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+        with open(model_path, "rb") as f:
+            model = pickle.load(f)  # nosec B301
+
+        if not isinstance(model, BinaryClassifier):
+            raise ValueError(f"Expected BinaryClassifier, got {type(model)}")
+
+        # Update device if different from config
+        if (
+            hasattr(model, "embedding_extractor")
+            and model.embedding_extractor.device != self.config.device
+        ):
+            self.logger.warning(
+                f"Device mismatch: model trained on {model.embedding_extractor.device}, "
+                f"test config specifies {self.config.device}. Recreating extractor..."
+            )
+
+            # CRITICAL: Explicit cleanup to prevent semaphore leaks (P0 bug fix)
+            old_device = str(model.embedding_extractor.device)
+            old_extractor = model.embedding_extractor
+
+            # Delete old extractor before creating new one
+            del model.embedding_extractor
+            del old_extractor
+
+            # Clear device-specific GPU cache
+            if old_device.startswith("cuda"):
+                torch.cuda.empty_cache()
+            elif old_device.startswith("mps"):
+                torch.mps.empty_cache()
+
+            self.logger.info(f"Cleaned up old extractor on {old_device}")
+
+            # NOW create new extractor (no leak)
+            batch_size = getattr(model, "batch_size", DEFAULT_BATCH_SIZE)
+            revision = getattr(model, "revision", "main")
+            model.embedding_extractor = ESMEmbeddingExtractor(
+                model.model_name, self.config.device, batch_size, revision=revision
+            )
+            model.device = self.config.device
+
+            self.logger.info(f"Created new extractor on {self.config.device}")
+
+        # Update batch_size if different from config
+        if (
+            hasattr(model, "embedding_extractor")
+            and model.embedding_extractor.batch_size != self.config.batch_size
+        ):
+            self.logger.info(
+                f"Updating batch_size from {model.embedding_extractor.batch_size} to {self.config.batch_size}"
+            )
+            model.embedding_extractor.batch_size = self.config.batch_size
+
+        self.logger.info(
+            f"Model loaded successfully: {model_path} on device: {model.embedding_extractor.device}"
+        )
+        return model
+
+    def embed_sequences(
+        self,
+        sequences: list[str],
+        model: BinaryClassifier,
+        dataset_name: str,
+        output_dir: str,
+    ) -> np.ndarray:
+        """Extract embeddings for sequences using the model's embedding extractor"""
+        # Ensure output directory exists before file I/O
+        os.makedirs(output_dir, exist_ok=True)
+
+        cache_file = os.path.join(output_dir, f"{dataset_name}_test_embeddings.pkl")
+
+        # Track this file for cleanup
+        if cache_file not in self.cached_embedding_files:
+            self.cached_embedding_files.append(cache_file)
+
+        # Try to load from cache
+        if os.path.exists(cache_file):
+            try:
+                self.logger.info(f"Loading cached embeddings from {cache_file}")
+                with open(cache_file, "rb") as f:
+                    embeddings: np.ndarray = pickle.load(f)  # nosec B301
+
+                # Validate shape and type
+                if not isinstance(embeddings, np.ndarray):
+                    raise ValueError(f"Invalid cache data type: {type(embeddings)}")
+                if embeddings.ndim != 2:
+                    raise ValueError(f"Invalid embedding shape: {embeddings.shape}")
+
+                if len(embeddings) == len(sequences):
+                    self.logger.info(f"Loaded {len(embeddings)} cached embeddings")
+                    return embeddings
+                else:
+                    self.logger.warning(
+                        "Cached embeddings size mismatch, recomputing..."
+                    )
+
+            except (pickle.UnpicklingError, EOFError, ValueError, AttributeError) as e:
+                self.logger.warning(
+                    f"Failed to load cached embeddings from {cache_file}: {e}. "
+                    "Recomputing embeddings..."
+                )
+                # Fall through to recomputation below
+
+        # Extract embeddings
+        self.logger.info(f"Extracting embeddings for {len(sequences)} sequences...")
+        embeddings = model.embedding_extractor.extract_batch_embeddings(sequences)
+
+        # Cache embeddings
+        with open(cache_file, "wb") as f:
+            pickle.dump(embeddings, f)
+        self.logger.info(f"Embeddings cached to {cache_file}")
+
+        return embeddings
+
+    def cleanup_cached_embeddings(self) -> None:
+        """Delete cached embedding files"""
+        self.logger.info("Cleaning up cached embedding files...")
+        for cache_file in self.cached_embedding_files:
+            if os.path.exists(cache_file):
+                try:
+                    os.remove(cache_file)
+                    self.logger.info(f"Deleted cached embeddings: {cache_file}")
+                except Exception as e:
+                    self.logger.warning(f"Failed to delete {cache_file}: {e}")
+
+    def _compute_output_directory(
+        self,
+        model_path: str | None,
+        dataset_name: str,
+    ) -> str:
+        """Compute output directory (hierarchical if model config available, else flat)."""
+        if model_path is None:
+            self.logger.warning("No model path provided, using flat output structure")
+            return self.config.output_dir
+
+        # Try to load model config JSON
+        model_config_path = (
+            Path(model_path)
+            .with_suffix("")
+            .with_name(Path(model_path).stem + "_config.json")
+        )
+
+        if not model_config_path.exists():
+            self.logger.info(
+                f"Model config not found at {model_config_path}, using flat output structure"
+            )
+            return self.config.output_dir
+
+        try:
+            with open(model_config_path) as f:
+                model_config = json.load(f)
+
+            model_name = model_config.get("model_name") or model_config.get(
+                "esm_model", ""
+            )
+            if not model_name:
+                raise ValueError("Model config missing 'model_name' or 'esm_model'")
+
+            classifier_config = model_config.get("classifier", {})
+
+            # Use shared utility for hierarchical path generation
+            hierarchical_path = get_hierarchical_test_results_dir(
+                base_dir=self.config.output_dir,
+                model_name=model_name,
+                classifier_config=classifier_config,
+                dataset_name=dataset_name,
+            )
+
+            # Extract shortnames for logging
+            model_short = extract_model_shortname(model_name)
+            classifier_short = extract_classifier_shortname(classifier_config)
+
+            self.logger.info(
+                f"Using hierarchical output: {hierarchical_path} "
+                f"(model={model_short}, classifier={classifier_short})"
+            )
+            return str(hierarchical_path)
+
+        except (json.JSONDecodeError, KeyError, ValueError) as e:
+            self.logger.warning(
+                f"Could not determine hierarchical path from model config: {e}. "
+                "Using flat structure."
+            )
+            return self.config.output_dir
+
+    def run_comprehensive_test(self) -> dict[str, dict[str, Any]]:
+        """Run testing pipeline"""
+        self.logger.info("Starting model testing")
+        self.logger.info(f"Models to test: {self.config.model_paths}")
+        self.logger.info(f"Datasets to test: {self.config.data_paths}")
+
+        all_results = {}
+        failed_datasets = []
+        failed_models = []
+
+        try:
+            # Test each dataset
+            for data_path in self.config.data_paths:
+                dataset_name = Path(data_path).stem
+                self.logger.info(f"\n{'=' * 60}")
+                self.logger.info(f"Testing on dataset: {dataset_name}")
+                self.logger.info(f"{'=' * 60}")
+
+                # Load dataset
+                try:
+                    sequences, labels_list = load_dataset(data_path, self.config)
+                    labels: np.ndarray = np.array(labels_list)
+                except Exception as e:
+                    self.logger.error(f"Failed to load dataset {data_path}: {e}")
+                    failed_datasets.append((dataset_name, str(e)))
+                    continue
+
+                dataset_results = {}
+
+                # Test each model
+                for model_path in self.config.model_paths:
+                    model_name = Path(model_path).stem
+                    self.logger.info(f"\nTesting model: {model_name}")
+
+                    output_dir_for_dataset = self._compute_output_directory(
+                        model_path, dataset_name
+                    )
+
+                    try:
+                        # Load model
+                        model = self.load_model(model_path)
+
+                        # Extract embeddings
+                        X_embedded = self.embed_sequences(
+                            sequences,
+                            model,
+                            f"{dataset_name}_{model_name}",
+                            output_dir_for_dataset,
+                        )
+
+                        # Evaluation (delegated to evaluation module)
+                        test_results = evaluate_pretrained(
+                            model,
+                            X_embedded,
+                            labels,
+                            model_name,
+                            dataset_name,
+                            self.config.metrics,
+                            self.config.threshold,
+                        )
+                        dataset_results[model_name] = test_results
+
+                        # Visualization (delegated to visualization module)
+                        single_model_results = {model_name: test_results}
+                        plot_confusion_matrix(
+                            single_model_results,
+                            dataset_name,
+                            output_dir=output_dir_for_dataset,
+                        )
+                        save_detailed_results(
+                            single_model_results,
+                            dataset_name,
+                            self.config.__dict__,
+                            output_dir=output_dir_for_dataset,
+                            save_predictions=self.config.save_predictions,
+                        )
+
+                    except Exception as e:
+                        self.logger.error(f"Failed to test model {model_path}: {e}")
+                        failed_models.append((f"{dataset_name}_{model_name}", str(e)))
+                        continue
+
+                # Generate aggregated multi-model report
+                if dataset_results:
+                    aggregated_output_dir = self.config.output_dir
+                    self.logger.info(
+                        f"Generating aggregated multi-model report for {dataset_name} "
+                        f"in {aggregated_output_dir}"
+                    )
+
+                    plot_confusion_matrix(
+                        dataset_results,
+                        dataset_name,
+                        output_dir=aggregated_output_dir,
+                    )
+                    save_detailed_results(
+                        dataset_results,
+                        dataset_name,
+                        self.config.__dict__,
+                        output_dir=aggregated_output_dir,
+                        save_predictions=self.config.save_predictions,
+                    )
+
+                all_results[dataset_name] = dataset_results
+
+            # Check if all tests failed
+            if not all_results:
+                error_msg = "All tests failed:\n"
+                if failed_datasets:
+                    error_msg += (
+                        f"  Failed datasets: {[name for name, _ in failed_datasets]}\n"
+                    )
+                if failed_models:
+                    error_msg += (
+                        f"  Failed models: {[name for name, _ in failed_models]}\n"
+                    )
+                raise RuntimeError(error_msg + "No successful test results to report.")
+
+            if failed_datasets or failed_models:
+                self.logger.warning(
+                    f"\nSome tests failed (datasets: {len(failed_datasets)}, "
+                    f"models: {len(failed_models)}). Check logs for details."
+                )
+
+            self.results = all_results
+            self.logger.info(
+                f"\nTesting completed. Results saved to: {self.config.output_dir}"
+            )
+
+        finally:
+            self.cleanup_cached_embeddings()
+
+        return all_results
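
End to end, the orchestrator is driven entirely by a TestConfig (a sketch; unspecified fields are assumed to take their defaults):

from antibody_training_esm.cli.testing.config import TestConfig
from antibody_training_esm.cli.testing.tester import ModelTester

config = TestConfig(
    model_paths=["experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"],
    data_paths=["data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv"],
    sequence_column="vh_sequence",
    label_column="label",
    output_dir="experiments/benchmarks",
)
tester = ModelTester(config)
results = tester.run_comprehensive_test()  # embeds, evaluates, plots, then purges embedding caches
print(results["VH_only_jain_86_p5e_s2"].keys())  # one entry per tested model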
src/antibody_training_esm/cli/testing/visualization.py
ADDED
@@ -0,0 +1,127 @@
+"""Plotting and result serialization utilities."""
+
+import logging
+import os
+from datetime import datetime
+from typing import Any
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+import yaml
+
+# Configure matplotlib
+plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "default")
+sns.set_palette("husl")
+
+logger = logging.getLogger(__name__)
+
+
+def plot_confusion_matrix(
+    results: dict[str, dict[str, Any]],
+    dataset_name: str,
+    output_dir: str,
+) -> None:
+    """
+    Create confusion matrix visualization (individual files per model).
+
+    Args:
+        results: Dictionary mapping model names to result dictionaries.
+        dataset_name: Name of the dataset.
+        output_dir: Directory to save plots.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+
+    logger.info(f"Creating confusion matrices for {dataset_name} in {output_dir}")
+
+    # Create individual confusion matrix for each model to prevent overrides
+    for model_name, model_results in results.items():
+        if "confusion_matrix" not in model_results:
+            logger.warning(f"No confusion matrix found for {model_name}, skipping plot")
+            continue
+
+        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
+        cm = model_results["confusion_matrix"]
+        sns.heatmap(
+            cm,
+            annot=True,
+            fmt="d",
+            cmap="Blues",
+            xticklabels=["Negative", "Positive"],
+            yticklabels=["Negative", "Positive"],
+            ax=ax,
+        )
+        ax.set_title(f"Confusion Matrix - {model_name} on {dataset_name}")
+        ax.set_ylabel("True Label")
+        ax.set_xlabel("Predicted Label")
+
+        plt.tight_layout()
+
+        # Save plot with model name to prevent overrides when testing multiple backbones
+        plot_file = os.path.join(
+            output_dir,
+            f"confusion_matrix_{model_name}_{dataset_name}.png",
+        )
+        plt.savefig(plot_file, dpi=300, bbox_inches="tight")
+        plt.close()
+
+        logger.info(f"Confusion matrix saved to {plot_file}")
+
+
+def save_detailed_results(
+    results: dict[str, dict[str, Any]],
+    dataset_name: str,
+    config_dict: dict[str, Any],
+    output_dir: str,
+    save_predictions: bool = True,
+) -> None:
+    """
+    Save detailed results to files (individual files per model).
+
+    Args:
+        results: Dictionary mapping model names to result dictionaries.
+        dataset_name: Name of the dataset.
+        config_dict: Configuration dictionary to embed in YAML.
+        output_dir: Directory to save results.
+        save_predictions: Whether to save prediction CSVs.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    # Save individual YAML for each model to prevent overrides
+    for model_name, model_results in results.items():
+        results_file = os.path.join(
+            output_dir,
+            f"detailed_results_{model_name}_{dataset_name}_{timestamp}.yaml",
+        )
+        with open(results_file, "w") as f:
+            yaml.dump(
+                {
+                    "dataset": dataset_name,
+                    "model": model_name,
+                    "config": config_dict,
+                    "results": model_results,
+                },
+                f,
+                default_flow_style=False,
+            )
+        logger.info(f"Detailed results saved to {results_file}")
+
+    # Save predictions if requested
+    if save_predictions:
+        for model_name, model_results in results.items():
+            if "predictions" in model_results:
+                pred_file = os.path.join(
+                    output_dir,
+                    f"predictions_{model_name}_{dataset_name}_{timestamp}.csv",
+                )
+                pred_df = pd.DataFrame(
+                    {
+                        "y_true": model_results["predictions"]["y_true"],
+                        "y_pred": model_results["predictions"]["y_pred"],
+                        "y_proba": model_results["predictions"]["y_proba"],
+                    }
+                )
+                pred_df.to_csv(pred_file, index=False)
+                logger.info(f"Predictions saved to {pred_file}")
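
Both helpers consume the same structure: a dict keyed by model name whose values are result dicts from evaluate_pretrained (a sketch; test_results stands for one such dict):

from antibody_training_esm.cli.testing.visualization import (
    plot_confusion_matrix,
    save_detailed_results,
)

results = {"boughter_vh_esm1v_logreg": test_results}  # test_results from evaluate_pretrained
plot_confusion_matrix(results, "VH_only_jain_86_p5e_s2", output_dir="experiments/benchmarks")
save_detailed_results(
    results,
    "VH_only_jain_86_p5e_s2",
    config_dict={},  # normally TestConfig.__dict__, as in tester.py
    output_dir="experiments/benchmarks",
)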
src/antibody_training_esm/cli/train.py
ADDED
@@ -0,0 +1,42 @@
+"""
+Training CLI - Hydra Entry Point
+
+Professional command-line interface for antibody model training.
+Uses Hydra for configuration management and supports dynamic overrides.
+
+Usage:
+    # Default config
+    antibody-train
+
+    # With overrides
+    antibody-train hardware.device=cuda training.batch_size=16
+
+    # Multi-run sweep
+    antibody-train --multirun classifier.C=0.1,1.0,10.0
+
+    # Help
+    antibody-train --help
+"""
+
+from antibody_training_esm.core.trainer import main as hydra_main
+
+
+def main() -> None:
+    """
+    Main entry point for training CLI
+
+    Delegates to Hydra-decorated main() in core.trainer.
+    This provides automatic config composition, override support,
+    and multi-run sweeps.
+
+    Note:
+        This function does not return an exit code (Hydra handles that).
+        Use try/except at a higher level if you need custom error handling.
+    """
+    # Delegate to Hydra entry point
+    # Hydra automatically parses sys.argv and handles all CLI logic
+    hydra_main()
+
+
+if __name__ == "__main__":
+    main()
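
Because main() is guarded by __name__ == "__main__", the same Hydra entry point can be reached without the console script (a sketch; assumes the package is installed):

# Equivalent to the antibody-train console script:
#   python -m antibody_training_esm.cli.train hardware.device=cpu
from antibody_training_esm.cli.train import main

main()  # Hydra parses sys.argv, composes conf/config.yaml, and runs the trainer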
src/antibody_training_esm/conf/__init__.py
ADDED
@@ -0,0 +1,9 @@
+"""
+Hydra configuration package
+
+Contains YAML configs and structured config schemas.
+"""
+
+# Import config_schema so its typed schema dataclasses are available package-wide.
+# NOTE: ConfigStore registrations were removed in config_schema.py (see note there).
+from . import config_schema  # noqa: F401
src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (356 Bytes).
src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc
ADDED
Binary file (5.3 kB).
src/antibody_training_esm/conf/classifier/logreg.yaml
ADDED
@@ -0,0 +1,12 @@
+type: logistic_regression
+C: 1.0
+penalty: l2
+solver: lbfgs
+max_iter: 1000
+random_state: ${training.random_state}
+class_weight: null
+cv_folds: 10
+stratify: true
+path: null
+# Optional path to the JSON config file (for .npz models)
+config_path: null
src/antibody_training_esm/conf/classifier/xgboost.yaml
ADDED
@@ -0,0 +1,14 @@
+type: xgboost
+n_estimators: 100
+max_depth: 6
+learning_rate: 0.3
+subsample: 1.0
+colsample_bytree: 1.0
+reg_alpha: 0.0
+reg_lambda: 1.0
+random_state: ${training.random_state}
+objective: binary:logistic
+cv_folds: 10
+stratify: true
+path: null
+config_path: null
src/antibody_training_esm/conf/config.yaml
ADDED
@@ -0,0 +1,36 @@
+defaults:
+  - model: esm1v
+  - classifier: logreg
+  - data: boughter_jain
+  - hardware: default
+  - hydra: default
+  - _self_
+
+# Training settings (matches current trainer.py requirements)
+training:
+  # Cross-validation
+  n_splits: 10
+  random_state: 42
+  stratify: true
+
+  # Evaluation metrics (list of metrics to compute)
+  metrics: [accuracy, precision, recall, f1, roc_auc]
+
+  # Model saving
+  save_model: true
+  model_name: boughter_vh_esm1v_logreg
+  model_save_dir: ./experiments/checkpoints
+
+  # Logging (Hydra-aware: relative to Hydra output dir, or logs/ in legacy mode)
+  log_level: INFO
+  log_file: training.log
+
+  # Performance optimization
+  batch_size: 8
+  num_workers: 4
+
+# Experiment metadata (Hydra manages output dirs)
+experiment:
+  name: novo_replication
+  description: "Train ESM-1v VH-based LogisticReg on Boughter, test on Jain"
+  tags: [baseline, esm1v, logreg]
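
The composed tree can be inspected programmatically with Hydra's compose API (a sketch; config_path is resolved relative to the calling module, so adjust it for your entry point):

from hydra import compose, initialize

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config",
        overrides=["classifier=xgboost", "hardware.device=cpu"],
    )

print(cfg.classifier.type)    # xgboost
print(cfg.training.n_splits)  # 10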
src/antibody_training_esm/conf/config_schema.py
ADDED
@@ -0,0 +1,142 @@
+"""
+Structured configuration schemas for Hydra
+
+Type-safe configuration using dataclasses with full field coverage
+validated against current trainer.py requirements.
+"""
+
+from dataclasses import dataclass, field
+
+# ConfigStore import removed - no longer needed since registrations are commented out
+# from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+
+
+@dataclass
+class ModelConfig:
+    """ESM model configuration (matches current model config structure)"""
+
+    name: str = "facebook/esm1v_t33_650M_UR90S_1"
+    revision: str = "main"
+    device: str = MISSING  # Provided by YAML interpolation ${hardware.device}
+
+
+@dataclass
+class ClassifierConfig:
+    """Classifier head configuration (matches current classifier config)"""
+
+    type: str = "logistic_regression"
+    C: float = 1.0
+    penalty: str = "l2"
+    solver: str = "lbfgs"
+    max_iter: int = 1000
+    random_state: int = (
+        MISSING  # Provided by YAML interpolation ${training.random_state}
+    )
+    class_weight: str | None = None
+    cv_folds: int = 10
+    stratify: bool = True
+
+
+@dataclass
+class DataConfig:
+    """Dataset configuration (ALL fields used by loaders.py + trainer.py)"""
+
+    # REQUIRED by loaders.py
+    source: str = "local"
+    train_file: str = MISSING  # Required
+    test_file: str = MISSING  # Required
+    sequence_column: str = "sequence"
+    label_column: str = "label"
+
+    # REQUIRED by trainer.py
+    embeddings_cache_dir: str = "./experiments/cache"
+
+    # Optional fields
+    dataset_name: str = "boughter_vh"
+    max_sequence_length: int = 1024
+    save_embeddings: bool = True
+
+    # Fragment metadata (testing only)
+    train_fragment: str = "VH"
+    test_fragment: str = "VH"
+    test_assay: str = "ELISA"
+    test_threshold: float = 0.5
+
+
+@dataclass
+class TrainingConfig:
+    """Training hyperparameters (ALL fields used by trainer.py)"""
+
+    # Cross-validation
+    n_splits: int = 10
+    random_state: int = 42
+    stratify: bool = True
+
+    # Evaluation metrics
+    metrics: list[str] = field(
+        default_factory=lambda: ["accuracy", "precision", "recall", "f1", "roc_auc"]
+    )
+
+    # Model saving
+    save_model: bool = True
+    model_name: str = "boughter_vh_esm1v_logreg"
+    model_save_dir: str = "./experiments/checkpoints"
+
+    # Logging (Hydra-aware: relative to Hydra output dir, or logs/ in legacy mode)
+    log_level: str = "INFO"
+    log_file: str = "training.log"  # Routes to logs/ dir in legacy mode, Hydra output dir in Hydra mode
+
+    # Performance optimization
+    batch_size: int = 8
+    num_workers: int = 4
+
+
+@dataclass
+class HardwareConfig:
+    """Hardware settings"""
+
+    device: str = "mps"
+    gpu_memory_fraction: float = 0.8
+    clear_cache_frequency: int = 100
+
+
+@dataclass
+class ExperimentConfig:
+    """Experiment metadata"""
+
+    name: str = "novo_replication"
+    description: str = "Train ESM-1v VH-based LogisticReg on Boughter, test on Jain"
+    tags: list[str] = field(default_factory=lambda: ["baseline", "esm1v", "logreg"])
+
+
+@dataclass
+class Config:
+    """Root configuration (complete schema matching current trainer.py)"""
+
+    model: ModelConfig = field(default_factory=ModelConfig)
+    classifier: ClassifierConfig = field(default_factory=ClassifierConfig)
+    data: DataConfig = field(default_factory=DataConfig)
+    training: TrainingConfig = field(default_factory=TrainingConfig)
+    hardware: HardwareConfig = field(default_factory=HardwareConfig)
+    experiment: ExperimentConfig = field(default_factory=ExperimentConfig)
+
+
+# ConfigStore registrations REMOVED to fix CLI override bug
+#
+# Root cause: Registering structured configs with the same names as YAML files
+# causes Hydra to prefer ConfigStore over YAML when using package-based config
+# loading (which the console script does). This breaks config group overrides.
+#
+# Known issue: Hydra structured configs strictly validate keys.
+# Overrides adding new keys require proper schema definition or +key syntax with strict mode disabled. See: https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching
+#
+# The dataclasses above are kept for type hints and validation in code, but are
+# no longer registered with ConfigStore. This allows YAML files to be the single
+# source of truth for configuration.
+#
+# cs = ConfigStore.instance()
+# cs.store(name="config", node=Config)
+# cs.store(group="model", name="esm1v", node=ModelConfig)
+# cs.store(group="classifier", name="logreg", node=ClassifierConfig)
+# cs.store(group="data", name="boughter_jain", node=DataConfig)
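
With the ConfigStore registrations removed, the dataclasses act as plain typed schemas that can be instantiated directly (a sketch):

from antibody_training_esm.conf.config_schema import Config, TrainingConfig

cfg = Config(training=TrainingConfig(n_splits=5, random_state=0))
print(cfg.training.metrics)  # ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
print(cfg.model.name)        # facebook/esm1v_t33_650M_UR90S_1 (device stays MISSING until composed)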
src/antibody_training_esm/conf/data/boughter_jain.yaml
ADDED
@@ -0,0 +1,23 @@
+# Data source (matches current loaders.py requirements)
+source: local
+dataset_name: boughter_vh
+
+# File paths
+train_file: data/train/boughter/canonical/VH_only_boughter_training.csv
+test_file: data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv
+
+# Data format options (required by loaders.py)
+# Jain canonical parity file uses 'vh_sequence'; align config to avoid column errors
+sequence_column: sequence
+label_column: label
+max_sequence_length: 1024
+
+# Embedding caching (required by trainer.py)
+save_embeddings: true
+embeddings_cache_dir: ./experiments/cache
+
+# Fragment metadata (for testing only)
+train_fragment: VH
+test_fragment: VH
+test_assay: ELISA
+test_threshold: 0.5
src/antibody_training_esm/conf/hardware/default.yaml
ADDED
@@ -0,0 +1,5 @@
+# Hardware configuration
+# Default to MPS for macOS performance (training/testing); Gradio app handles stability fallback
+device: mps
+gpu_memory_fraction: 0.8
+clear_cache_frequency: 100
src/antibody_training_esm/conf/hydra/default.yaml
ADDED
@@ -0,0 +1,10 @@
+# Hydra output directory management
+run:
+  dir: experiments/runs/${experiment.name}/${now:%Y-%m-%d_%H-%M-%S}
+
+sweep:
+  dir: experiments/runs/sweeps/${experiment.name}
+  subdir: ${hydra.job.num}
+
+job:
+  chdir: false # Don't change working directory
src/antibody_training_esm/conf/model/esm1v.yaml
ADDED
@@ -0,0 +1,4 @@
+name: facebook/esm1v_t33_650M_UR90S_1
+revision: main
+# Device is taken from hardware.device (default: mps); override via the hardware config or CLI
+device: ${hardware.device}
src/antibody_training_esm/conf/model/esm2_650m.yaml
ADDED
@@ -0,0 +1,3 @@
+name: facebook/esm2_t33_650M_UR50D
+revision: main
+device: ${hardware.device}
src/antibody_training_esm/conf/predict.yaml
ADDED
@@ -0,0 +1,26 @@
+# @package _global_
+
+defaults:
+  - /model: esm1v
+  - /classifier: logreg
+  - /hardware: default
+  - _self_
+
+input_file: null
+output_file: "predictions.csv"
+sequence_column: "sequence"
+assay_type: null # Options: "PSR", "ELISA", or null
+threshold: 0.5 # Ignored if assay_type is set
+
+gradio:
+  server_name: "0.0.0.0"
+  server_port: 7860
+  share: false
+  queue:
+    concurrency_limit: 2 # Based on 8GB VRAM (3GB per ESM-1v inference)
+    max_size: 10 # Prevents unbounded queue growth
+  log_level: INFO
+
+hydra:
+  job:
+    chdir: false
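
The gradio block maps onto Gradio's queue/launch parameters; a sketch of the assumed wiring (demo stands for the Blocks app built in app.py, and mapping queue.concurrency_limit onto default_concurrency_limit is an assumption):

import gradio as gr

with gr.Blocks() as demo:  # placeholder; the real UI lives in app.py
    gr.Markdown("Antibody non-specificity predictor")

demo.queue(
    max_size=10,                  # gradio.queue.max_size: bounds queue growth
    default_concurrency_limit=2,  # gradio.queue.concurrency_limit: ~3 GB VRAM per ESM-1v call
)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)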
src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml
ADDED
@@ -0,0 +1,7 @@
+model_paths: [experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl]
+data_paths: [data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv]
+sequence_column: vh_sequence
+label_column: label
+output_dir: experiments/benchmarks
+device: cpu
+batch_size: 8
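
This file feeds straight into the YAML-to-TestConfig loader from config.py (a sketch; the path assumes the repo root as working directory):

import yaml

from antibody_training_esm.cli.testing.config import TestConfig

with open("src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml") as f:
    config = TestConfig(**yaml.safe_load(f))

print(config.sequence_column)  # vh_sequence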
src/antibody_training_esm/core/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""
+Core ML Module
+
+Professional ML components for antibody classification:
+- ESM embedding extraction
+- Binary classification
+- Training pipelines
+- Model serialization (pickle + NPZ+JSON)
+"""
+
+from antibody_training_esm.core.classifier import BinaryClassifier
+from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor
+from antibody_training_esm.core.trainer import load_model_from_npz
+
+__all__ = [
+    "BinaryClassifier",
+    "ESMEmbeddingExtractor",
+    "load_model_from_npz",
+]
src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (705 Bytes).
src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc
ADDED
Binary file (14.2 kB).
src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc
ADDED
Binary file (4.66 kB).