diff --git a/README.md b/README.md index 359a61c088511624da0998ee14572d64198f4551..34d8a24a0614a3dad50c1f26112fad32cbc11b06 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,67 @@ --- -title: Antibody Predictor -emoji: 🐨 -colorFrom: pink -colorTo: indigo +title: Antibody Non-Specificity Predictor +emoji: 🧬 +colorFrom: blue +colorTo: green sdk: gradio -sdk_version: 6.0.0 -app_file: app.py +sdk_version: "5.0.0" +app_file: app.py pinned: false +license: mit +tags: + - antibody + - protein + - ESM + - gradio + - polyreactivity + - machine-learning --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# 🧬 Antibody Non-Specificity Predictor + +Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) or Variable Light (VL) sequences using ESM-1v protein language models. + +## Model + +- **Architecture:** ESM-1v (650M parameters) + Logistic Regression +- **Training Data:** Boughter dataset (914 antibodies, ELISA polyreactivity) +- **Methodology:** Sakhnini et al. (2025) - Prediction of Antibody Non-Specificity using PLMs + +## Usage + +1. Paste your antibody VH or VL amino acid sequence +2. Click "🔬 Predict Non-Specificity" +3. 
Get prediction (specific vs non-specific) + probability + +## Supported Input + +- **Valid characters:** Standard amino acids (ACDEFGHIKLMNPQRSTVWY) +- **Max length:** 2000 amino acids +- **Auto-cleaning:** Lowercase automatically converted to uppercase + +## Examples + +The app includes example sequences: +- Standard VH (128aa) +- Standard VL (107aa) +- Short VH (Herceptin-like) + +## Citation + +If you use this tool in your research, please cite: + +```bibtex +@article{sakhnini2025antibody, + title={Prediction of Antibody Non-Specificity using Protein Language Models}, + author={Sakhnini, et al.}, + year={2025} +} +``` + +## Repository + +Full source code: [antibody_training_pipeline_ESM](https://github.com/The-Obstacle-Is-The-Way/antibody_training_pipeline_ESM) + +## License + +MIT License - See repository for details diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..10aa1a26607311c82267315edd0ced4094798ed9 --- /dev/null +++ b/app.py @@ -0,0 +1,152 @@ +""" +Hugging Face Spaces Gradio App for Antibody Non-Specificity Prediction + +Simplified deployment version (no Hydra, no complex dependencies). +Works on HF Spaces free CPU tier. + +Local app (src/antibody_training_esm/cli/app.py) remains unchanged. 
+""" + +import logging +import os + +import gradio as gr +import torch +from pydantic import ValidationError + +from antibody_training_esm.core.prediction import Predictor +from antibody_training_esm.models.prediction import PredictionRequest + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# HF Spaces environment detection +IS_HF_SPACE = os.getenv("SPACE_ID") is not None + +# Model path (either local or downloaded from HF Hub) +MODEL_PATH = os.getenv( + "MODEL_PATH", "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl" +) + +# ESM model name +MODEL_NAME = "facebook/esm1v_t33_650M_UR90S_1" + +# Force CPU for HF Spaces free tier +DEVICE = "cpu" + +# Load model globally (HF Spaces best practice) +logger.info(f"Loading model from {MODEL_PATH}...") +predictor = Predictor( + model_name=MODEL_NAME, classifier_path=MODEL_PATH, device=DEVICE, config_path=None +) + +# Warm up model +try: + logger.info("Warming up model...") + predictor.predict_single("QVQL") + logger.info("Model ready!") +except Exception as e: + logger.warning(f"Warmup failed (non-fatal): {e}") + + +def predict_sequence(sequence: str) -> tuple[str, str]: + """ + Prediction function for Gradio interface. + + Args: + sequence: Antibody amino acid sequence + + Returns: + Tuple of (prediction, probability) + """ + try: + # Validate with Pydantic + request = PredictionRequest(sequence=sequence) + + # Log request + logger.info(f"Processing sequence: length={len(request.sequence)}") + + # Predict + result = predictor.predict_single(request) + + # Format probability + prob_percent = f"{result.probability:.1%}" + + return result.prediction, prob_percent + + except ValidationError as e: + # User-friendly error message + error_msg = e.errors()[0]["msg"] + raise gr.Error(error_msg) from e + except torch.cuda.OutOfMemoryError as e: + logger.error("GPU OOM during inference") + raise gr.Error( + "Server overloaded (GPU OOM). 
Please try again in a moment." + ) from e + except Exception as e: + logger.exception("Unexpected prediction failure") + raise gr.Error(f"Prediction failed: {str(e)}") from e + + +# Example sequences +examples = [ + [ + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS" + ], # Standard VH + [ + "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK" + ], # Standard VL + [ + "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS" + ], # Short VH +] + +# Create Gradio interface +iface = gr.Interface( + fn=predict_sequence, + inputs=gr.TextArea( + lines=7, + max_lines=20, + max_length=2000, + label="Antibody Sequence (VH or VL)", + placeholder="Paste amino acid sequence here (e.g., QVQL...)", + info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).", + show_copy_button=True, + ), + outputs=[ + gr.Textbox(label="Prediction", show_copy_button=True), + gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True), + ], + title="🧬 Antibody Non-Specificity Predictor", + description=( + "Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) " + "or Variable Light (VL) sequences using ESM-1v protein language models.\n\n" + "**Model:** ESM-1v (650M parameters) + Logistic Regression\n" + "**Training:** Boughter dataset (914 antibodies, ELISA polyreactivity)\n" + "**Citation:** Sakhnini et al. 
(2025) - Prediction of Antibody Non-Specificity using PLMs" + ), + article=( + f"**Model:** {MODEL_NAME}\n" + f"**Device:** {DEVICE}\n" + f"**Environment:** {'Hugging Face Spaces' if IS_HF_SPACE else 'Local'}" + ), + examples=examples, + cache_examples=False, # Don't cache on HF Spaces (saves disk) + flagging_mode="never", + analytics_enabled=False, + submit_btn="šŸ”¬ Predict Non-Specificity", + clear_btn="šŸ—‘ļø Clear", +) + +# Enable queue for concurrency +iface.queue(default_concurrency_limit=2, max_size=10) + +# Launch app +if __name__ == "__main__": + iface.launch( + server_name="0.0.0.0", # Required for HF Spaces + server_port=7860, + share=False, + show_api=False, # No public REST API + ) diff --git a/experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl b/experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl new file mode 100644 index 0000000000000000000000000000000000000000..eeafe1c91a8b407c364275c0866e8a81832fea87 --- /dev/null +++ b/experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4f77cadfd0ccf3a12c24ce142a91c82b4481d5153a0af662ac4b05a78ef6670 +size 11314 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..30ec224ee0b267d856afd6180938dbecf4201e6a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,215 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/antibody_training_esm"] +include = [ + "src/antibody_training_esm/conf/**/*.yaml", + "src/antibody_training_esm/conf/**/*.py", +] + +[tool.hatch.build.targets.sdist] +# Source distribution must include all source files + configs +include = [ + "src/antibody_training_esm/**/*.py", + "src/antibody_training_esm/conf/**/*.yaml", + "tests/**/*.py", + "README.md", + "pyproject.toml", + "LICENSE", +] + +[project] +name = "antibody-training-esm" +version = "0.7.0" 
+description = "Professional antibody training pipeline using ESM protein language models" +license = {text = "Apache-2.0"} +requires-python = ">=3.12" +dependencies = [ + "authlib>=1.6.5", + "biopython>=1.80", + "brotli>=1.2.0", + "datasets>=4.2.0", + "h2>=4.3.0", + "hydra-core>=1.3.2", + "jupyterlab>=4.4.9", + "matplotlib>=3.7.0", + "more-itertools", + "numpy>=1.24.0", + "pandas>=2.0.0", + "plotly", + "pyparsing>=3.0.0", + "PyYAML>=6.0.0", + "riot_na", + "scikit-learn>=1.3.0", + "scipy>=1.10.0", + "seaborn>=0.12.0", + "torch>=2.6.0", + "tqdm>=4.65.0", + "transformers>=4.30.0", + "xgboost>=2.0.0", + "gradio>=4.0.0", +] + +[project.optional-dependencies] +validation = [ + "pydantic>=2.10.0", # Stable v2 release + "pydantic-settings>=2.6.0", # For future config management + "pandera>=0.20.0", # Phase 3: Data Integrity +] +dev = [ + # Testing + "pytest>=8.3.0", + "pytest-cov>=6.0.0", + "pytest-xdist>=3.6.0", + "pytest-sugar>=1.0.0", + + # Linting & Formatting + "ruff>=0.8.0", + + # Type Checking + "mypy>=1.13.0", + "pandas-stubs>=2.2.0", + + # Security + "bandit[toml]>=1.7.0", + + # Pre-commit + "pre-commit>=4.0.0", + + # Documentation + "mkdocs>=1.6.0", + "mkdocs-material>=9.5.0", + "mkdocstrings[python]>=0.26.0", + "mkdocs-gen-files>=0.5.0", + "mkdocs-literate-nav>=0.6.0", + "mkdocs-section-index>=0.3.0", + "pymdown-extensions>=10.0.0", +] + +[project.scripts] +# Point directly to Hydra-decorated function to enable config group overrides +# (antibody-train model=esm2_650m classifier=xgboost now works correctly) +antibody-train = "antibody_training_esm.core.trainer:main" +antibody-test = "antibody_training_esm.cli.test:main" +antibody-preprocess = "antibody_training_esm.cli.preprocess:main" +antibody-predict = "antibody_training_esm.cli.predict:main" +antibody-app = "antibody_training_esm.cli.app:main" + +[tool.ruff] +target-version = "py312" +line-length = 88 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes 
+ "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade + "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify +] +ignore = [ + "E501", # line too long (handled by formatter) +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] +"tests/**/*" = ["ARG"] +"experiments/**/*" = ["ALL"] +"reference_repos/**/*" = ["ALL"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +ignore_missing_imports = true +exclude = [ + "experiments/", + "reference_repos/", + "site/", # MkDocs generated documentation + "tests/unit/cli/test_train.py", # Legacy CLI tests (deprecated) +] + +[tool.pytest.ini_options] +# Pytest Configuration (canonical source - pytest.ini deleted for single source of truth) +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + # Output formatting + "-v", + "--tb=short", + "--strict-markers", + "-ra", + # Coverage reporting + "--cov=src/antibody_training_esm", + "--cov-report=html", + "--cov-report=term-missing", + # Performance + "--maxfail=10", +] +markers = [ + "unit: Unit tests (fast, no I/O) - Core business logic", + "integration: Integration tests (medium speed, some I/O) - Component interactions", + "e2e: End-to-end tests (slow, full pipeline) - Full workflows", + "slow: Tests that take >1s to run", + "gpu: Tests that require GPU (skip in CI with: -m 'not gpu')", + "legacy: Legacy tests for backward compatibility (deprecated, will be removed)", +] +filterwarnings = [ + # sklearn deprecation warnings + "ignore:.*__sklearn_tags__.*:DeprecationWarning:sklearn.utils._tags", + # sklearn convergence warnings (expected with small test datasets) + "ignore:.*lbfgs failed to converge.*:sklearn.exceptions.ConvergenceWarning", + "ignore:.*lbfgs failed to 
converge.*:UserWarning:sklearn.linear_model._logistic", + # sklearn scoring warnings (expected when testing edge cases) + "ignore:.*Scoring failed.*:UserWarning:sklearn.model_selection._validation", + # sklearn undefined metric warnings (expected with edge case test data) + "ignore:.*Precision is ill-defined.*:sklearn.exceptions.UndefinedMetricWarning", + "ignore:.*Precision is ill-defined.*:UserWarning:sklearn.metrics._classification", + # pytest collection warnings (TestConfig is a dataclass, not a test class) + "ignore:.*cannot collect test class.*TestConfig.*:pytest.PytestCollectionWarning", + # General deprecation warnings + "ignore::DeprecationWarning", + "ignore::PendingDeprecationWarning", +] + +[tool.coverage.run] +source = ["src"] +omit = [ + "tests/*", + "experiments/*", + "reference_repos/*", + "**/__pycache__/*", + ".venv/*", + "**/conftest.py", +] +branch = true + +[tool.coverage.report] +precision = 2 +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[dependency-groups] +dev = [ + "openpyxl>=3.1.5", + "types-pyyaml>=6.0.12.20250915", +] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..72857ff66d248bc74c730fba023695841982bb0c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +# Hugging Face Spaces Requirements +# Minimal dependencies for antibody prediction demo + +# Core ML +torch>=2.0.0 +transformers>=4.30.0 +scikit-learn>=1.3.0 +scipy>=1.10.0 +joblib>=1.3.0 + +# Data handling +pandas>=2.0.0 +numpy>=1.24.0 + +# Configuration +omegaconf>=2.3.0 + +# Validation +pydantic>=2.0.0 + +# Gradio UI +gradio>=5.0.0 + +# Progress bars +tqdm>=4.65.0 + +# Install local package (antibody_training_esm) +. 
diff --git a/src/antibody_training_esm/__init__.py b/src/antibody_training_esm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e188d7aea644b66e24b49dc01a727b34355ce510 Binary files /dev/null and b/src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/__pycache__/settings.cpython-312.pyc b/src/antibody_training_esm/__pycache__/settings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d5d178c15765724f114ac22753f08b2af641d37 Binary files /dev/null and b/src/antibody_training_esm/__pycache__/settings.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/__init__.py b/src/antibody_training_esm/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23efd94f566a8da71474d82f5f688ec83b0c57ff --- /dev/null +++ b/src/antibody_training_esm/cli/__init__.py @@ -0,0 +1,10 @@ +""" +CLI Module + +Professional command-line interfaces for antibody training pipeline: +- antibody-train: Model training +- antibody-test: Model evaluation +- antibody-preprocess: Dataset preprocessing +""" + +__all__ = [] diff --git a/src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128495e4d8c3fd503977f453ed1be1c7d40daf6b Binary files /dev/null and b/src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc b/src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5c210921bedf08c24117118c89804c20d0f10b11 Binary files /dev/null and b/src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc b/src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c67b454048c400df143f4444136e5b9ff25f54 Binary files /dev/null and b/src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc b/src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9979782a64f56f54fb49a2fc27279566cc232ef0 Binary files /dev/null and b/src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc b/src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d5a56df0a2ef2050116bc42e401422c7c7c9c88 Binary files /dev/null and b/src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc b/src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae1be672614c1ccbdd4c3106273889895f98d3a9 Binary files /dev/null and b/src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/app.py b/src/antibody_training_esm/cli/app.py new file mode 100644 index 0000000000000000000000000000000000000000..230f909dd1b1cebbc95cd76f696f40fced2de0eb --- /dev/null +++ b/src/antibody_training_esm/cli/app.py @@ -0,0 +1,197 @@ +""" +This module contains the Gradio app for the antibody non-specificity prediction pipeline. 
+""" + +import logging +import platform +from pathlib import Path + +import gradio as gr +import hydra +import torch +from omegaconf import DictConfig +from pydantic import ValidationError + +from antibody_training_esm.core.prediction import Predictor +from antibody_training_esm.models.prediction import PredictionRequest + +# Configure logging +logger = logging.getLogger(__name__) + + +def launch_gradio_app(cfg: DictConfig) -> None: + """ + Launches the Gradio web UI for antibody prediction. + + This function sets up a Gradio interface that allows users to input an + antibody sequence and receive a prediction for its non-specificity. + + Args: + cfg: The Hydra configuration object. + """ + # Set log level from config + logging.basicConfig( + level=getattr(logging, cfg.gradio.log_level.upper(), logging.INFO) + ) + + # Robust Device & Threading Configuration + # ------------------------------------------------------------------------- + # 1. Determine the optimal device for inference + # - Prefer CUDA if available (Linux/Windows GPU boxes) + # - Force CPU on macOS if MPS is detected to avoid Gradio+MPS SegFaults + # - Default to configured value otherwise + device = cfg.model.get("device", "cpu") + + if platform.system() == "Darwin" and device == "mps": + logger.warning( + "macOS detected. Forcing CPU for Gradio app stability (MPS workaround)." + ) + device = "cpu" + + # 2. Configure Threading to prevent OpenMP SegFaults on macOS + # - On macOS/CPU, PyTorch's OpenMP runtime can crash inside Gradio threads. + # - We restrict it to 1 thread to ensure stability. + # - Linux/CUDA systems remain untouched and can use full parallelism. + if platform.system() == "Darwin" and device == "cpu": + logger.warning( + "macOS/CPU detected. Setting torch.set_num_threads(1) to prevent OpenMP crashes." 
+ ) + torch.set_num_threads(1) + + if cfg.classifier.path is None: + raise ValueError( + "Classifier path must be specified via command-line override:\n" + " classifier.path=experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl" + ) + classifier_path = Path(cfg.classifier.path) + if not classifier_path.exists(): + raise FileNotFoundError( + f"Classifier file not found at {classifier_path}. " + "Train a model (e.g., `make train`) or download a published checkpoint first." + ) + + # Instantiate the predictor + config_path = getattr(cfg.classifier, "config_path", None) + predictor = Predictor( + model_name=cfg.model.name, + classifier_path=cfg.classifier.path, + device=device, + config_path=config_path, + ) + + # Warm-up: Run a dummy prediction to load the model into memory eagerly + try: + logger.info("Warming up model with dummy prediction...") + predictor.predict_single("QVQL") + logger.info("Model warmed up and ready.") + except Exception as e: + logger.warning(f"Model warm-up failed (non-fatal): {e}") + + def predict_sequence(sequence: str) -> tuple[str, str]: + """ + Prediction function for the Gradio interface. + + Args: + sequence: The antibody sequence to predict. + + Returns: + A tuple containing the prediction string and the formatted probability. + """ + try: + # Validate with Pydantic (replaces old validate_input) + request = PredictionRequest(sequence=sequence) + + # Log request (observability) + logger.info(f"Processing: length={len(request.sequence)}") + + # Predict (returns PydanticResult) + result = predictor.predict_single(request) + + # Format probability + prob_percent = f"{result.probability:.1%}" + + return result.prediction, prob_percent + + except ValidationError as e: + # Extract first error message for user-friendly display + error_msg = e.errors()[0]["msg"] + raise gr.Error(error_msg) from e + except torch.cuda.OutOfMemoryError as e: + logger.error("GPU OOM during inference") + raise gr.Error( + "Server overloaded (GPU OOM). 
Please try again in a moment." + ) from e + except Exception as e: + logger.exception("Unexpected prediction failure") + raise gr.Error(f"Prediction failed: {str(e)}") from e + + # Example sequences (Diverse set) + examples = [ + [ + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS" + ], # Standard VH + [ + "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK" + ], # Standard VL + [ + "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS" + ], # Short VH (Herceptin-like) + ] + + # Create the Gradio interface + iface = gr.Interface( + fn=predict_sequence, + inputs=gr.TextArea( + lines=7, + max_lines=20, + max_length=2000, + label="Antibody Sequence (VH or VL)", + placeholder="Paste amino acid sequence here (e.g., QVQL...)", + info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).", + show_copy_button=True, + ), + outputs=[ + gr.Textbox(label="Prediction", show_copy_button=True), + gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True), + ], + title="Antibody Non-Specificity Predictor", + description=( + "Enter an antibody Variable Heavy (VH) or Variable Light (VL) sequence " + "to predict its non-specificity (polyreactivity)." + ), + article=f"Model: {cfg.model.name} | Device: {device}", + examples=examples, + cache_examples=True, + flagging_mode="never", + analytics_enabled=False, + submit_btn="Predict Non-Specificity", + ) + + # Enable queueing for concurrency management + """ + Queue Configuration: + - concurrency_limit: Based on available VRAM (approx 3GB per ESM-1v inference). + - max_size: Prevents unbounded queue growth under load. 
+ """ + iface.queue( + default_concurrency_limit=cfg.gradio.queue.concurrency_limit, + max_size=cfg.gradio.queue.max_size, + ) + + # Launch the app with hardened settings + iface.launch( + server_name=cfg.gradio.server_name, + server_port=cfg.gradio.server_port, + share=cfg.gradio.share, + show_api=False, + ) + + +@hydra.main(config_path="../conf", config_name="predict", version_base=None) +def main(cfg: DictConfig) -> None: + """Main function to run the Gradio app.""" + launch_gradio_app(cfg) + + +if __name__ == "__main__": + main() diff --git a/src/antibody_training_esm/cli/predict.py b/src/antibody_training_esm/cli/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a5f7c02004901b141d3b77fbe7d1a1a0469fc864 --- /dev/null +++ b/src/antibody_training_esm/cli/predict.py @@ -0,0 +1,116 @@ +import sys +from pathlib import Path +from typing import cast + +import hydra +import pandas as pd +from omegaconf import DictConfig +from pydantic import ValidationError + +from antibody_training_esm.core.config import SEQUENCE_PREVIEW_LENGTH +from antibody_training_esm.core.prediction import Predictor, run_prediction +from antibody_training_esm.models.prediction import AssayType, PredictionRequest + + +def predict_sequence_cli( + sequence: str, threshold: float, assay_type: AssayType | None, cfg: DictConfig +) -> None: + """CLI prediction with Pydantic validation.""" + config_path = getattr(cfg.classifier, "config_path", None) + + # Instantiate predictor (loading model) + try: + predictor = Predictor( + model_name=cfg.model.name, + classifier_path=cfg.classifier.path, + config_path=config_path, + ) + except Exception as e: + print(f"Error loading model: {e}") + sys.exit(1) + + try: + request = PredictionRequest( + sequence=sequence, + threshold=threshold, + assay_type=assay_type, + ) + result = predictor.predict_single(request) + + # Print formatted output + print( + f"Sequence: {result.sequence[:SEQUENCE_PREVIEW_LENGTH]}..." 
+ if len(result.sequence) > SEQUENCE_PREVIEW_LENGTH + else f"Sequence: {result.sequence}" + ) + print(f"Prediction: {result.prediction}") + print(f"Probability: {result.probability:.2%}") + + except ValidationError as e: + print("āŒ Validation Error:") + for error in e.errors(): + # loc is a tuple, e.g. ('sequence',) + loc = error["loc"][0] if error["loc"] else "root" + print(f" - {loc}: {error['msg']}") + sys.exit(1) + + +@hydra.main(config_path="../conf", config_name="predict", version_base=None) +def main(cfg: DictConfig) -> None: + """Main function to run the prediction CLI.""" + + # Check for single sequence prediction mode + sequence = getattr(cfg, "sequence", None) + if sequence: + threshold = getattr(cfg, "threshold", 0.5) + assay_type = cast(AssayType | None, getattr(cfg, "assay_type", None)) + predict_sequence_cli(sequence, threshold, assay_type, cfg) + return + + # Validate required arguments for batch mode + if cfg.input_file is None: + raise ValueError( + "Input file must be specified via command-line override: `input_file=...`" + ) + + if cfg.classifier.path is None: + raise ValueError( + "Classifier path must be specified via command-line override:\n" + " classifier.path=experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl\n" + " # OR for production models (.npz):\n" + " classifier.path=experiments/.../model.npz classifier.config_path=.../model_config.json\n" + "\nExample usage:\n" + " uv run antibody-predict \\\n" + " input_file=data/test.csv \\\n" + " output_file=predictions.csv \\\n" + " classifier.path=path/to/model.pkl" + ) + classifier_path = Path(cfg.classifier.path) + if not classifier_path.exists(): + raise FileNotFoundError( + f"Classifier file not found at {classifier_path}. " + "Train a model (e.g., `make train`) or download a published checkpoint first." 
+ ) + + try: + # Load input data + input_df = pd.read_csv(cfg.input_file) + + # Run prediction + output_df = run_prediction(input_df, cfg) + + # Save output data + output_df.to_csv(cfg.output_file, index=False) + + print(f"Predictions saved to {cfg.output_file}") + + except FileNotFoundError: + print(f"Error: Input file not found at {cfg.input_file}") + exit(1) + except Exception as e: + print(f"An error occurred: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/antibody_training_esm/cli/preprocess.py b/src/antibody_training_esm/cli/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7ddee406d6061a678714dff57a8c9c7667334d --- /dev/null +++ b/src/antibody_training_esm/cli/preprocess.py @@ -0,0 +1,84 @@ +""" +Preprocessing CLI + +Professional command-line interface for dataset preprocessing. +""" + +import argparse +import sys + + +def main() -> int: + """ + Main entry point for preprocessing CLI. + + This CLI does NOT run preprocessing - it only provides guidance on which + preprocessing scripts to use. Preprocessing is handled by specialized + scripts that are the Single Source of Truth (SSOT). + """ + parser = argparse.ArgumentParser( + description="Antibody dataset preprocessing guidance", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +NOTE: This CLI does NOT run preprocessing. It provides guidance on which +preprocessing scripts to use. Each dataset has unique requirements and the +scripts maintain bit-for-bit parity with published methods. + """, + ) + + parser.add_argument( + "--dataset", + "-d", + type=str, + required=True, + choices=["jain", "harvey", "shehata", "boughter"], + help="Dataset to get preprocessing guidance for", + ) + + args = parser.parse_args() + + try: + print("\nāš ļø The 'antibody-preprocess' CLI is not implemented") + print( + "\nDataset preprocessing is handled by specialized scripts, not this CLI." 
+ ) + print( + "These scripts are the authoritative source of truth for data transformation." + ) + print(f"\nFor {args.dataset} dataset, use:") + + script_paths = { + "jain": "preprocessing/jain/step2_preprocess_p5e_s2.py", + "harvey": "preprocessing/harvey/step2_extract_fragments.py", + "shehata": "preprocessing/shehata/step2_extract_fragments.py", + "boughter": "preprocessing/boughter/stage2_stage3_annotation_qc.py", + } + + script = script_paths.get(args.dataset) + if script: + print(f" python {script}") + + print("\nWhy use scripts instead of this CLI?") + print(" • Scripts are Single Source of Truth (SSOT) for preprocessing") + print( + " • Each dataset has unique requirements (DNA translation, PSR thresholds, etc.)" + ) + print(" • Scripts maintain bit-for-bit parity with published methods") + print(" • CLI is for loading preprocessed data, not creating it") + + print("\nFor more information:") + print(" • See src/antibody_training_esm/datasets/README.md") + print(" • See docs/boughter/boughter_data_sources.md (dataset-specific)") + + return 0 + + except KeyboardInterrupt: + print("\nāŒ Error: Interrupted by user", file=sys.stderr) + return 1 + except Exception as e: + print(f"\nāŒ Error: {e}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/antibody_training_esm/cli/test.py b/src/antibody_training_esm/cli/test.py new file mode 100644 index 0000000000000000000000000000000000000000..7616e6236bcacd928fc841f78d3131a4d3c0ec90 --- /dev/null +++ b/src/antibody_training_esm/cli/test.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +""" +Test CLI for Antibody Classification Pipeline + +Professional command-line interface for testing trained antibody classifiers: +1. Load trained models from pickle files +2. Evaluate on test datasets with performance metrics +3. 
Generate confusion matrices and comprehensive logging + +Usage: + antibody-test --model experiments/checkpoints/antibody_classifier.pkl --data sample_data.csv + antibody-test --config test_config.yaml + antibody-test --model m1.pkl m2.pkl --data d1.csv d2.csv +""" + +import argparse +import sys + +from antibody_training_esm.cli.testing.config import ( + TestConfig, + create_sample_test_config, + load_config_file, +) +from antibody_training_esm.cli.testing.tester import ModelTester + + +def main() -> int: + """Main entry point for antibody-test CLI""" + parser = argparse.ArgumentParser( + description="Testing for antibody classification models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Test single model on single dataset (auto-detects threshold from dataset name) + antibody-test --model experiments/checkpoints/antibody_classifier.pkl --data sample_data.csv + + # Test on PSR dataset with auto-detected threshold (0.5495 for Harvey/Shehata) + antibody-test --model model.pkl --data data/test/harvey/fragments/VHH_only_harvey.csv + + # Test multiple models on multiple datasets + antibody-test --model experiments/checkpoints/model1.pkl experiments/checkpoints/model2.pkl --data dataset1.csv dataset2.csv + + # Use configuration file + antibody-test --config test_config.yaml + + # Override device, batch size, and threshold + antibody-test --config test_config.yaml --device cuda --batch-size 64 --threshold 0.6 + + # Create sample configuration + antibody-test --create-config + """, + ) + + parser.add_argument( + "--model", nargs="+", help="Path(s) to trained model pickle files" + ) + parser.add_argument("--data", nargs="+", help="Path(s) to test dataset CSV files") + parser.add_argument("--config", help="Path to test configuration YAML file") + parser.add_argument( + "--output-dir", + default="./experiments/benchmarks", + help="Output directory for results", + ) + parser.add_argument( + "--device", + choices=["cpu", "cuda", "mps"], + 
help="Device to use for inference (overrides config)", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size for embedding extraction (overrides config)", + ) + parser.add_argument( + "--threshold", + type=float, + help="Manual decision threshold override (default: auto-detect from dataset name). " + "Use 0.5 for ELISA datasets (Boughter, Jain) or 0.5495 for PSR datasets (Harvey, Shehata).", + ) + parser.add_argument( + "--sequence-column", + type=str, + help="Column name for sequences in dataset (default: 'sequence', overrides config)", + ) + parser.add_argument( + "--label-column", + type=str, + help="Column name for labels in dataset (default: 'label', overrides config)", + ) + parser.add_argument( + "--create-config", action="store_true", help="Create sample configuration file" + ) + + args = parser.parse_args() + + # Create sample config if requested + if args.create_config: + create_sample_test_config() + return 0 + + # Load configuration + if args.config: + config = load_config_file(args.config) + else: + if not args.model or not args.data: + parser.error("Either --config or both --model and --data must be specified") + + config = TestConfig( + model_paths=args.model, data_paths=args.data, output_dir=args.output_dir + ) + + # Override config with command line arguments + if args.device: + config.device = args.device + if args.batch_size: + config.batch_size = args.batch_size + if args.threshold: + config.threshold = args.threshold + if args.sequence_column: + config.sequence_column = args.sequence_column + if args.label_column: + config.label_column = args.label_column + + # Run testing + try: + tester = ModelTester(config) + results = tester.run_comprehensive_test() + + print(f"\n{'=' * 60}") + print("TESTING COMPLETED SUCCESSFULLY!") + print(f"{'=' * 60}") + print(f"Results saved to: {config.output_dir}") + + # Print summary + for dataset_name, dataset_results in results.items(): + print(f"\nDataset: {dataset_name}") + print("-" * 40) 
+ for model_name, model_results in dataset_results.items(): + print(f"Model: {model_name}") + if "test_scores" in model_results: + for metric, value in model_results["test_scores"].items(): + print(f" {metric}: {value:.4f}") + + return 0 + + except KeyboardInterrupt: + print("Error during testing: Interrupted by user", file=sys.stderr) + return 1 + except Exception as e: + print(f"Error during testing: {e}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/antibody_training_esm/cli/testing/__init__.py b/src/antibody_training_esm/cli/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dbbb6a678f11913e90e82e65fd827fddcf4e48 --- /dev/null +++ b/src/antibody_training_esm/cli/testing/__init__.py @@ -0,0 +1 @@ +"""Test CLI package.""" diff --git a/src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9031e6244574938f2f3cc772eacaf796d6d23ea9 Binary files /dev/null and b/src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc b/src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ceee2f95447771fb719852ae9dcb2f9acdb2169 Binary files /dev/null and b/src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc b/src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0b1fd4066e0096f9c1dcc83062d8a9267882aea Binary files /dev/null and b/src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc differ diff --git 
a/src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc b/src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aef6bab3d66841cc68231d1a436708d9d1f6fa3 Binary files /dev/null and b/src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc b/src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..335e143b20b0de943abf65ca5910242cf4c7eaea Binary files /dev/null and b/src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc b/src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06bdb24f564d2e975ca74ca5e63e0b67cc25a78b Binary files /dev/null and b/src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc differ diff --git a/src/antibody_training_esm/cli/testing/config.py b/src/antibody_training_esm/cli/testing/config.py new file mode 100644 index 0000000000000000000000000000000000000000..58fd18a53c4493ccb0ba866c50f25d73f257102e --- /dev/null +++ b/src/antibody_training_esm/cli/testing/config.py @@ -0,0 +1,62 @@ +"""Configuration management for the testing pipeline.""" + +from dataclasses import dataclass + +import yaml + +from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE + + +@dataclass +class TestConfig: + """Configuration for testing pipeline""" + + model_paths: list[str] + data_paths: list[str] + sequence_column: str = "sequence" # Column name for sequences in dataset + label_column: str = "label" # Column name for labels in dataset + output_dir: str = "./experiments/benchmarks" + metrics: list[str] | None = None + 
save_predictions: bool = True + batch_size: int = DEFAULT_BATCH_SIZE # Batch size for embedding extraction + device: str = "mps" # Device to use for inference [cuda, cpu, mps] - MUST match training config + threshold: float | None = ( + None # Manual threshold override (None = auto-detect from dataset name) + ) + + def __post_init__(self) -> None: + if self.metrics is None: + self.metrics = [ + "accuracy", + "precision", + "recall", + "f1", + "roc_auc", + "pr_auc", + ] + + +def load_config_file(config_path: str) -> TestConfig: + """Load test configuration from YAML file""" + with open(config_path) as f: + config_dict = yaml.safe_load(f) + + return TestConfig(**config_dict) + + +def create_sample_test_config() -> None: + """Create a sample test configuration file""" + sample_config = { + "model_paths": ["./experiments/checkpoints/antibody_classifier.pkl"], + "data_paths": ["./sample_data.csv"], + "sequence_column": "sequence", + "label_column": "label", + "output_dir": "./experiments/benchmarks", + "metrics": ["accuracy", "precision", "recall", "f1", "roc_auc", "pr_auc"], + "save_predictions": True, + } + + with open("test_config.yaml", "w") as f: + yaml.dump(sample_config, f, default_flow_style=False) + + print("Sample test configuration created: test_config.yaml") diff --git a/src/antibody_training_esm/cli/testing/data.py b/src/antibody_training_esm/cli/testing/data.py new file mode 100644 index 0000000000000000000000000000000000000000..050d34ea217944b1ba8bba8223c7b5f4bd1b75cf --- /dev/null +++ b/src/antibody_training_esm/cli/testing/data.py @@ -0,0 +1,73 @@ +"""Dataset loading and validation utilities.""" + +import logging +import os + +import pandas as pd + +from antibody_training_esm.cli.testing.config import TestConfig + +logger = logging.getLogger(__name__) + + +def load_dataset(data_path: str, config: TestConfig) -> tuple[list[str], list[int]]: + """ + Load dataset from CSV file using configured column names. + + Args: + data_path: Path to the CSV file. 
+ config: Test configuration object containing column names. + + Returns: + Tuple of (sequences, labels). + """ + logger.info(f"Loading dataset from {data_path}") + + if not os.path.exists(data_path): + raise FileNotFoundError(f"Dataset file not found: {data_path}") + + # Defensive: Handle legacy files with comment headers + # New files (post-HF cleanup) are standard CSVs without comments + df = pd.read_csv(data_path, comment="#") + + sequence_col = config.sequence_column + label_col = config.label_column + + if sequence_col not in df.columns: + raise ValueError( + f"Sequence column '{sequence_col}' not found in dataset. Available columns: {list(df.columns)}" + ) + if label_col not in df.columns: + raise ValueError( + f"Label column '{label_col}' not found in dataset. Available columns: {list(df.columns)}" + ) + + # CRITICAL VALIDATION: Check for NaN labels (P0 bug fix) + nan_count = df[label_col].isna().sum() + if nan_count > 0: + raise ValueError( + f"CRITICAL: Dataset contains {nan_count} NaN labels! " + f"This will corrupt evaluation metrics. " + f"Please use the curated canonical test file (e.g., " + f"data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv with no NaNs)." + ) + + # For Jain test sets, validate expected size (allow legacy 94 + canonical 86) + if "jain" in data_path.lower() and "test" in data_path.lower(): + expected_sizes = {94, 86} + if len(df) not in expected_sizes: + raise ValueError( + f"Jain test set has {len(df)} antibodies but expected one of {sorted(expected_sizes)}. " + f"Using the wrong test set will produce invalid metrics. " + f"Please use the correct curated file (preferred: " + f"data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv)." 
+ ) + + sequences = df[sequence_col].tolist() + labels = df[label_col].tolist() + + logger.info( + f"Loaded {len(sequences)} samples from {data_path} (sequence_col='{sequence_col}', label_col='{label_col}')" + ) + logger.info(f" Label distribution: {pd.Series(labels).value_counts().to_dict()}") + return sequences, labels diff --git a/src/antibody_training_esm/cli/testing/evaluation.py b/src/antibody_training_esm/cli/testing/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..a69dac9ca558bdbc558db64dd647b0cb5d2f8ace --- /dev/null +++ b/src/antibody_training_esm/cli/testing/evaluation.py @@ -0,0 +1,134 @@ +"""Metric calculation and model evaluation utilities.""" + +import logging +from typing import Any + +import numpy as np +from sklearn.metrics import ( + classification_report, + confusion_matrix, +) + +from antibody_training_esm.core.classifier import BinaryClassifier +from antibody_training_esm.models.artifact import EvaluationMetrics + +logger = logging.getLogger(__name__) + + +def detect_assay_type(dataset_name: str) -> str | None: + """ + Auto-detect assay type from dataset name for threshold selection + + Args: + dataset_name: Name of the dataset (e.g., "VH_only_jain", "VHH_only_harvey") + + Returns: + 'ELISA' for ELISA-based datasets (Boughter, Jain) + 'PSR' for PSR-based datasets (Harvey, Shehata) + None if unable to detect + + Notes: + Novo Nordisk (Sakhnini et al. 2025, Section 2.7): + "Antibodies characterised by the PSR assay appear to be on a different + non-specificity spectrum than that from the non-specificity ELISA assay." + + PSR datasets require threshold=0.5495 for optimal performance. + ELISA datasets use standard threshold=0.5. 
+ """ + dataset_lower = dataset_name.lower() + + # PSR-based datasets (Harvey, Shehata) + if any(marker in dataset_lower for marker in ["harvey", "shehata"]): + return "PSR" + + # ELISA-based datasets (Boughter, Jain) + if any(marker in dataset_lower for marker in ["boughter", "jain"]): + return "ELISA" + + # Unable to detect - will use default threshold + return None + + +def evaluate_pretrained( + model: BinaryClassifier, + X: np.ndarray, + y: np.ndarray, + model_name: str, + dataset_name: str, + _metrics_list: list[str] | None = None, + threshold_override: float | None = None, +) -> dict[str, Any]: + """ + Evaluate pretrained model directly on test set (no retraining) + + Args: + model: The trained BinaryClassifier. + X: Embeddings (features). + y: True labels. + model_name: Name of the model for logging. + dataset_name: Name of the dataset for logging. + _metrics_list: List of metrics to calculate (default: all). + threshold_override: Optional manual threshold. + + Returns: + Dictionary of results including scores, predictions, and reports. + Contains 'metrics' key with EvaluationMetrics object. + """ + logger.info(f"Evaluating pretrained model {model_name} on {dataset_name}") + + # Determine threshold: manual override > auto-detect > default 0.5 + if threshold_override is not None: + # Manual override via CLI + threshold = threshold_override + logger.info(f"Using manual threshold override: {threshold}") + else: + # Auto-detect assay type from dataset name + assay_type = detect_assay_type(dataset_name) + if assay_type is not None: + threshold = model.ASSAY_THRESHOLDS[assay_type] + logger.info( + f"Auto-detected assay type: {assay_type} → threshold={threshold} " + f"(Dataset: {dataset_name})" + ) + else: + threshold = 0.5 + logger.warning( + f"Unable to auto-detect assay type for '{dataset_name}'. " + f"Using default threshold={threshold}. " + f"For optimal results, specify --threshold or use standard dataset names." 
+ ) + + # Get predictions using the pretrained model with appropriate threshold + y_pred = model.predict( + X, threshold=threshold, assay_type=None + ) # threshold already determined + y_proba = model.predict_proba(X)[:, 1] + + # Create Pydantic metrics + eval_metrics = EvaluationMetrics.from_sklearn_metrics( + y, + y_pred, + y_proba.reshape(-1, 1) if y_proba.ndim == 1 else y_proba, + dataset_name=dataset_name, + ) + + # Calculate legacy results for compatibility with visualization tools + results = { + "metrics": eval_metrics, # Store Pydantic model + "test_scores": eval_metrics.model_dump( + exclude={"confusion_matrix", "dataset_name", "n_samples"} + ), + "predictions": {"y_true": y, "y_pred": y_pred, "y_proba": y_proba}, + "confusion_matrix": confusion_matrix(y, y_pred), + "classification_report": classification_report(y, y_pred, output_dict=True), + } + + # Log results + logger.info(f"Test results for {model_name} on {dataset_name}:") + logger.info(f" Accuracy: {eval_metrics.accuracy:.4f}") + if eval_metrics.f1 is not None: + logger.info(f" F1: {eval_metrics.f1:.4f}") + if eval_metrics.roc_auc is not None: + logger.info(f" ROC-AUC: {eval_metrics.roc_auc:.4f}") + + return results diff --git a/src/antibody_training_esm/cli/testing/tester.py b/src/antibody_training_esm/cli/testing/tester.py new file mode 100644 index 0000000000000000000000000000000000000000..edbcb67a02d50d414c3ddeb37aab93f134fa7e29 --- /dev/null +++ b/src/antibody_training_esm/cli/testing/tester.py @@ -0,0 +1,384 @@ +"Model orchestration logic." 
import json
import logging
import os
import pickle  # nosec B403
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import torch

from antibody_training_esm.cli.testing.config import TestConfig
from antibody_training_esm.cli.testing.data import load_dataset
from antibody_training_esm.cli.testing.evaluation import evaluate_pretrained
from antibody_training_esm.cli.testing.visualization import (
    plot_confusion_matrix,
    save_detailed_results,
)
from antibody_training_esm.core.classifier import BinaryClassifier
from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE
from antibody_training_esm.core.directory_utils import (
    extract_classifier_shortname,
    extract_model_shortname,
    get_hierarchical_test_results_dir,
)
from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor


class ModelTester:
    """Model testing orchestrator.

    Loads pickled BinaryClassifier models, embeds test sequences (with
    on-disk caching), delegates evaluation and visualization, and aggregates
    results across every (dataset, model) pair in the TestConfig.
    """

    def __init__(self, config: TestConfig):
        # config: fully-resolved TestConfig (paths, device, threshold, columns)
        self.config = config
        self.logger = self._setup_logging()
        # results: populated by run_comprehensive_test(); dataset -> model -> results
        self.results: dict[str, Any] = {}
        # Embedding cache files created during a run; deleted in the finally
        # block of run_comprehensive_test() via cleanup_cached_embeddings().
        self.cached_embedding_files: list[str] = []  # Track cached files for cleanup

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

    def _setup_logging(self) -> logging.Logger:
        """Setup logging configuration.

        Attaches a timestamped file handler (under output_dir) plus a stream
        handler to the root logger.

        NOTE(review): logging.basicConfig is a no-op when the root logger is
        already configured, so a second ModelTester in the same process keeps
        logging to the first run's file — confirm this is intended.
        """
        # Create output directory if it doesn't exist
        os.makedirs(self.config.output_dir, exist_ok=True)

        log_file = os.path.join(
            self.config.output_dir,
            f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log",
        )

        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
        )

        return logging.getLogger(__name__)

    def load_model(self, model_path: str) -> BinaryClassifier:
        """Load trained model from pickle file.

        Validates the unpickled type, then reconciles the embedded
        embedding_extractor with the test config: the extractor is rebuilt if
        the device differs, and its batch_size is overwritten if it differs.

        Raises:
            FileNotFoundError: if model_path does not exist.
            ValueError: if the pickle is not a BinaryClassifier.
        """
        self.logger.info(f"Loading model from {model_path}")

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Trusted local artifact; pickle is unsafe on untrusted input (nosec).
        with open(model_path, "rb") as f:
            model = pickle.load(f)  # nosec B301

        if not isinstance(model, BinaryClassifier):
            raise ValueError(f"Expected BinaryClassifier, got {type(model)}")

        # Update device if different from config
        if (
            hasattr(model, "embedding_extractor")
            and model.embedding_extractor.device != self.config.device
        ):
            self.logger.warning(
                f"Device mismatch: model trained on {model.embedding_extractor.device}, "
                f"test config specifies {self.config.device}. Recreating extractor..."
            )

            # CRITICAL: Explicit cleanup to prevent semaphore leaks (P0 bug fix)
            old_device = str(model.embedding_extractor.device)
            old_extractor = model.embedding_extractor

            # Delete old extractor before creating new one
            del model.embedding_extractor
            del old_extractor

            # Clear device-specific GPU cache
            if old_device.startswith("cuda"):
                torch.cuda.empty_cache()
            elif old_device.startswith("mps"):
                torch.mps.empty_cache()

            self.logger.info(f"Cleaned up old extractor on {old_device}")

            # NOW create new extractor (no leak); fall back to defaults for
            # attributes absent from older pickled models.
            batch_size = getattr(model, "batch_size", DEFAULT_BATCH_SIZE)
            revision = getattr(model, "revision", "main")
            model.embedding_extractor = ESMEmbeddingExtractor(
                model.model_name, self.config.device, batch_size, revision=revision
            )
            model.device = self.config.device

            self.logger.info(f"Created new extractor on {self.config.device}")

        # Update batch_size if different from config
        if (
            hasattr(model, "embedding_extractor")
            and model.embedding_extractor.batch_size != self.config.batch_size
        ):
            self.logger.info(
                f"Updating batch_size from {model.embedding_extractor.batch_size} to {self.config.batch_size}"
            )
            model.embedding_extractor.batch_size = self.config.batch_size

        self.logger.info(
            f"Model loaded successfully: {model_path} on device: {model.embedding_extractor.device}"
        )
        return model

    def embed_sequences(
        self,
        sequences: list[str],
        model: BinaryClassifier,
        dataset_name: str,
        output_dir: str,
    ) -> np.ndarray:
        """Extract embeddings for sequences using the model's embedding extractor.

        Embeddings are cached as a pickle in output_dir; a cached file is
        reused only if it unpickles to a 2-D ndarray whose length matches
        len(sequences), otherwise embeddings are recomputed and re-cached.
        """
        # Ensure output directory exists before file I/O
        os.makedirs(output_dir, exist_ok=True)

        cache_file = os.path.join(output_dir, f"{dataset_name}_test_embeddings.pkl")

        # Track this file for cleanup
        if cache_file not in self.cached_embedding_files:
            self.cached_embedding_files.append(cache_file)

        # Try to load from cache
        if os.path.exists(cache_file):
            try:
                self.logger.info(f"Loading cached embeddings from {cache_file}")
                with open(cache_file, "rb") as f:
                    embeddings: np.ndarray = pickle.load(f)  # nosec B301

                # Validate shape and type
                if not isinstance(embeddings, np.ndarray):
                    raise ValueError(f"Invalid cache data type: {type(embeddings)}")
                if embeddings.ndim != 2:
                    raise ValueError(f"Invalid embedding shape: {embeddings.shape}")

                if len(embeddings) == len(sequences):
                    self.logger.info(f"Loaded {len(embeddings)} cached embeddings")
                    return embeddings
                else:
                    self.logger.warning(
                        "Cached embeddings size mismatch, recomputing..."
                    )

            except (pickle.UnpicklingError, EOFError, ValueError, AttributeError) as e:
                # Corrupt/stale cache is non-fatal: log and recompute.
                self.logger.warning(
                    f"Failed to load cached embeddings from {cache_file}: {e}. "
                    "Recomputing embeddings..."
                )
                # Fall through to recomputation below

        # Extract embeddings
        self.logger.info(f"Extracting embeddings for {len(sequences)} sequences...")
        embeddings = model.embedding_extractor.extract_batch_embeddings(sequences)

        # Cache embeddings
        with open(cache_file, "wb") as f:
            pickle.dump(embeddings, f)
        self.logger.info(f"Embeddings cached to {cache_file}")

        return embeddings

    def cleanup_cached_embeddings(self) -> None:
        """Delete cached embedding files (best-effort; failures only warn)."""
        self.logger.info("Cleaning up cached embedding files...")
        for cache_file in self.cached_embedding_files:
            if os.path.exists(cache_file):
                try:
                    os.remove(cache_file)
                    self.logger.info(f"Deleted cached embeddings: {cache_file}")
                except Exception as e:
                    self.logger.warning(f"Failed to delete {cache_file}: {e}")

    def _compute_output_directory(
        self,
        model_path: str | None,
        dataset_name: str,
    ) -> str:
        """Compute output directory (hierarchical if model config available, else flat).

        Looks for a sibling '<model_stem>_config.json' next to the model
        pickle; when found and parseable, builds a hierarchical path via
        get_hierarchical_test_results_dir. Any failure falls back to the flat
        config.output_dir.
        """
        if model_path is None:
            self.logger.warning("No model path provided, using flat output structure")
            return self.config.output_dir

        # Try to load model config JSON: 'model.pkl' -> 'model_config.json'
        model_config_path = (
            Path(model_path)
            .with_suffix("")
            .with_name(Path(model_path).stem + "_config.json")
        )

        if not model_config_path.exists():
            self.logger.info(
                f"Model config not found at {model_config_path}, using flat output structure"
            )
            return self.config.output_dir

        try:
            with open(model_config_path) as f:
                model_config = json.load(f)

            # Accept either key; older configs used 'esm_model'.
            model_name = model_config.get("model_name") or model_config.get(
                "esm_model", ""
            )
            if not model_name:
                raise ValueError("Model config missing 'model_name' or 'esm_model'")

            classifier_config = model_config.get("classifier", {})

            # Use shared utility for hierarchical path generation
            hierarchical_path = get_hierarchical_test_results_dir(
                base_dir=self.config.output_dir,
                model_name=model_name,
                classifier_config=classifier_config,
                dataset_name=dataset_name,
            )

            # Extract shortnames for logging
            model_short = extract_model_shortname(model_name)
            classifier_short = extract_classifier_shortname(classifier_config)

            self.logger.info(
                f"Using hierarchical output: {hierarchical_path} "
                f"(model={model_short}, classifier={classifier_short})"
            )
            return str(hierarchical_path)

        except (json.JSONDecodeError, KeyError, ValueError) as e:
            self.logger.warning(
                f"Could not determine hierarchical path from model config: {e}. "
                "Using flat structure."
            )
            return self.config.output_dir

    def run_comprehensive_test(self) -> dict[str, dict[str, Any]]:
        """Run testing pipeline.

        For every dataset x model pair: load, embed, evaluate, and write
        per-model plots/results to the (possibly hierarchical) output dir,
        then write an aggregated multi-model report to the flat output dir.
        Individual failures are collected and skipped; raises RuntimeError
        only if nothing succeeded. Embedding caches are always cleaned up.
        """
        self.logger.info("Starting model testing")
        self.logger.info(f"Models to test: {self.config.model_paths}")
        self.logger.info(f"Datasets to test: {self.config.data_paths}")

        all_results = {}
        failed_datasets = []
        failed_models = []

        try:
            # Test each dataset
            for data_path in self.config.data_paths:
                dataset_name = Path(data_path).stem
                self.logger.info(f"\n{'=' * 60}")
                self.logger.info(f"Testing on dataset: {dataset_name}")
                self.logger.info(f"{'=' * 60}")

                # Load dataset; a bad dataset skips to the next one.
                try:
                    sequences, labels_list = load_dataset(data_path, self.config)
                    labels: np.ndarray = np.array(labels_list)
                except Exception as e:
                    self.logger.error(f"Failed to load dataset {data_path}: {e}")
                    failed_datasets.append((dataset_name, str(e)))
                    continue

                dataset_results = {}

                # Test each model
                for model_path in self.config.model_paths:
                    model_name = Path(model_path).stem
                    self.logger.info(f"\nTesting model: {model_name}")

                    output_dir_for_dataset = self._compute_output_directory(
                        model_path, dataset_name
                    )

                    try:
                        # Load model
                        model = self.load_model(model_path)

                        # Extract embeddings (cache key includes model name so
                        # different backbones don't share embeddings)
                        X_embedded = self.embed_sequences(
                            sequences,
                            model,
                            f"{dataset_name}_{model_name}",
                            output_dir_for_dataset,
                        )

                        # Evaluation (delegated to evaluation module)
                        test_results = evaluate_pretrained(
                            model,
                            X_embedded,
                            labels,
                            model_name,
                            dataset_name,
                            self.config.metrics,
                            self.config.threshold,
                        )
                        dataset_results[model_name] = test_results

                        # Visualization (delegated to visualization module)
                        single_model_results = {model_name: test_results}
                        plot_confusion_matrix(
                            single_model_results,
                            dataset_name,
                            output_dir=output_dir_for_dataset,
                        )
                        save_detailed_results(
                            single_model_results,
                            dataset_name,
                            self.config.__dict__,
                            output_dir=output_dir_for_dataset,
                            save_predictions=self.config.save_predictions,
                        )

                    except Exception as e:
                        self.logger.error(f"Failed to test model {model_path}: {e}")
                        failed_models.append((f"{dataset_name}_{model_name}", str(e)))
                        continue

                # Generate aggregated multi-model report (flat output dir)
                if dataset_results:
                    aggregated_output_dir = self.config.output_dir
                    self.logger.info(
                        f"Generating aggregated multi-model report for {dataset_name} "
                        f"in {aggregated_output_dir}"
                    )

                    plot_confusion_matrix(
                        dataset_results,
                        dataset_name,
                        output_dir=aggregated_output_dir,
                    )
                    save_detailed_results(
                        dataset_results,
                        dataset_name,
                        self.config.__dict__,
                        output_dir=aggregated_output_dir,
                        save_predictions=self.config.save_predictions,
                    )

                all_results[dataset_name] = dataset_results

            # Check if all tests failed
            if not all_results:
                error_msg = "All tests failed:\n"
                if failed_datasets:
                    error_msg += (
                        f"  Failed datasets: {[name for name, _ in failed_datasets]}\n"
                    )
                if failed_models:
                    error_msg += (
                        f"  Failed models: {[name for name, _ in failed_models]}\n"
                    )
                raise RuntimeError(error_msg + "No successful test results to report.")

            if failed_datasets or failed_models:
                self.logger.warning(
                    f"\nSome tests failed (datasets: {len(failed_datasets)}, "
                    f"models: {len(failed_models)}). Check logs for details."
                )

            self.results = all_results
            self.logger.info(
                f"\nTesting completed. Results saved to: {self.config.output_dir}"
            )

        finally:
            # Always remove embedding caches, even on failure/interrupt.
            self.cleanup_cached_embeddings()

        return all_results


"""Plotting and result serialization utilities."""

import logging
import os
from datetime import datetime
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml

# Configure matplotlib
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "default")
sns.set_palette("husl")

logger = logging.getLogger(__name__)


def plot_confusion_matrix(
    results: dict[str, dict[str, Any]],
    dataset_name: str,
    output_dir: str,
) -> None:
    """
    Create confusion matrix visualization (individual files per model).

    Args:
        results: Dictionary mapping model names to result dictionaries.
        dataset_name: Name of the dataset.
        output_dir: Directory to save plots.
+ """ + os.makedirs(output_dir, exist_ok=True) + + logger.info(f"Creating confusion matrices for {dataset_name} in {output_dir}") + + # Create individual confusion matrix for each model to prevent overrides + for model_name, model_results in results.items(): + if "confusion_matrix" not in model_results: + logger.warning(f"No confusion matrix found for {model_name}, skipping plot") + continue + + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + cm = model_results["confusion_matrix"] + sns.heatmap( + cm, + annot=True, + fmt="d", + cmap="Blues", + xticklabels=["Negative", "Positive"], + yticklabels=["Negative", "Positive"], + ax=ax, + ) + ax.set_title(f"Confusion Matrix - {model_name} on {dataset_name}") + ax.set_ylabel("True Label") + ax.set_xlabel("Predicted Label") + + plt.tight_layout() + + # Save plot with model name to prevent overrides when testing multiple backbones + plot_file = os.path.join( + output_dir, + f"confusion_matrix_{model_name}_{dataset_name}.png", + ) + plt.savefig(plot_file, dpi=300, bbox_inches="tight") + plt.close() + + logger.info(f"Confusion matrix saved to {plot_file}") + + +def save_detailed_results( + results: dict[str, dict[str, Any]], + dataset_name: str, + config_dict: dict[str, Any], + output_dir: str, + save_predictions: bool = True, +) -> None: + """ + Save detailed results to files (individual files per model). + + Args: + results: Dictionary mapping model names to result dictionaries. + dataset_name: Name of the dataset. + config_dict: Configuration dictionary to embed in YAML. + output_dir: Directory to save results. + save_predictions: Whether to save prediction CSVs. 
+ """ + os.makedirs(output_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Save individual YAML for each model to prevent overrides + for model_name, model_results in results.items(): + results_file = os.path.join( + output_dir, + f"detailed_results_{model_name}_{dataset_name}_{timestamp}.yaml", + ) + with open(results_file, "w") as f: + yaml.dump( + { + "dataset": dataset_name, + "model": model_name, + "config": config_dict, + "results": model_results, + }, + f, + default_flow_style=False, + ) + logger.info(f"Detailed results saved to {results_file}") + + # Save predictions if requested + if save_predictions: + for model_name, model_results in results.items(): + if "predictions" in model_results: + pred_file = os.path.join( + output_dir, + f"predictions_{model_name}_{dataset_name}_{timestamp}.csv", + ) + pred_df = pd.DataFrame( + { + "y_true": model_results["predictions"]["y_true"], + "y_pred": model_results["predictions"]["y_pred"], + "y_proba": model_results["predictions"]["y_proba"], + } + ) + pred_df.to_csv(pred_file, index=False) + logger.info(f"Predictions saved to {pred_file}") diff --git a/src/antibody_training_esm/cli/train.py b/src/antibody_training_esm/cli/train.py new file mode 100644 index 0000000000000000000000000000000000000000..547d41f7f1adc0da548822525daf6cd5379ee370 --- /dev/null +++ b/src/antibody_training_esm/cli/train.py @@ -0,0 +1,42 @@ +""" +Training CLI - Hydra Entry Point + +Professional command-line interface for antibody model training. +Uses Hydra for configuration management and supports dynamic overrides. 
+ +Usage: + # Default config + antibody-train + + # With overrides + antibody-train hardware.device=cuda training.batch_size=16 + + # Multi-run sweep + antibody-train --multirun classifier.C=0.1,1.0,10.0 + + # Help + antibody-train --help +""" + +from antibody_training_esm.core.trainer import main as hydra_main + + +def main() -> None: + """ + Main entry point for training CLI + + Delegates to Hydra-decorated main() in core.trainer. + This provides automatic config composition, override support, + and multi-run sweeps. + + Note: + This function does not return an exit code (Hydra handles that). + Use try/except at a higher level if you need custom error handling. + """ + # Delegate to Hydra entry point + # Hydra automatically parses sys.argv and handles all CLI logic + hydra_main() + + +if __name__ == "__main__": + main() diff --git a/src/antibody_training_esm/conf/__init__.py b/src/antibody_training_esm/conf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2495737b0af206e4c0b97a01b07b1d530c317f80 --- /dev/null +++ b/src/antibody_training_esm/conf/__init__.py @@ -0,0 +1,9 @@ +""" +Hydra configuration package + +Contains YAML configs and structured config schemas. +""" + +# Import config_schema to execute ConfigStore registrations +# This MUST run at import time for structured configs to work +from . 
import config_schema # noqa: F401 diff --git a/src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5444224ad9f8d3ebeef92e9144a1c31f0bc43a58 Binary files /dev/null and b/src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc b/src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c48e28bde48c7dd875a7580830c5a217bb96ecf Binary files /dev/null and b/src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc differ diff --git a/src/antibody_training_esm/conf/classifier/logreg.yaml b/src/antibody_training_esm/conf/classifier/logreg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dc80c2fd706d7b9e58ed9956ccec90bdc445a40f --- /dev/null +++ b/src/antibody_training_esm/conf/classifier/logreg.yaml @@ -0,0 +1,12 @@ +type: logistic_regression +C: 1.0 +penalty: l2 +solver: lbfgs +max_iter: 1000 +random_state: ${training.random_state} +class_weight: null +cv_folds: 10 +stratify: true +path: null +# Optional path to the JSON config file (for .npz models) +config_path: null diff --git a/src/antibody_training_esm/conf/classifier/xgboost.yaml b/src/antibody_training_esm/conf/classifier/xgboost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0010f4bfbb088ca9d654c32711cc441a6e94ecdf --- /dev/null +++ b/src/antibody_training_esm/conf/classifier/xgboost.yaml @@ -0,0 +1,14 @@ +type: xgboost +n_estimators: 100 +max_depth: 6 +learning_rate: 0.3 +subsample: 1.0 +colsample_bytree: 1.0 +reg_alpha: 0.0 +reg_lambda: 1.0 +random_state: ${training.random_state} +objective: binary:logistic +cv_folds: 10 +stratify: true +path: null +config_path: null diff --git 
a/src/antibody_training_esm/conf/config.yaml b/src/antibody_training_esm/conf/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1cc8eda7dbf6c9dfdcbd39fc894db5117370d69 --- /dev/null +++ b/src/antibody_training_esm/conf/config.yaml @@ -0,0 +1,36 @@ +defaults: + - model: esm1v + - classifier: logreg + - data: boughter_jain + - hardware: default + - hydra: default + - _self_ + +# Training settings (matches current trainer.py requirements) +training: + # Cross-validation + n_splits: 10 + random_state: 42 + stratify: true + + # Evaluation metrics (list of metrics to compute) + metrics: [accuracy, precision, recall, f1, roc_auc] + + # Model saving + save_model: true + model_name: boughter_vh_esm1v_logreg + model_save_dir: ./experiments/checkpoints + + # Logging (Hydra-aware: relative to Hydra output dir, or logs/ in legacy mode) + log_level: INFO + log_file: training.log + + # Performance optimization + batch_size: 8 + num_workers: 4 + +# Experiment metadata (Hydra manages output dirs) +experiment: + name: novo_replication + description: "Train ESM-1v VH-based LogisticReg on Boughter, test on Jain" + tags: [baseline, esm1v, logreg] diff --git a/src/antibody_training_esm/conf/config_schema.py b/src/antibody_training_esm/conf/config_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebacb5255bffe183849f4ac1f8927c5081a90a1 --- /dev/null +++ b/src/antibody_training_esm/conf/config_schema.py @@ -0,0 +1,142 @@ +""" +Structured configuration schemas for Hydra + +Type-safe configuration using dataclasses with full field coverage +validated against current trainer.py requirements. 
+""" + +from dataclasses import dataclass, field + +# ConfigStore import removed - no longer needed since registrations are commented out +# from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclass +class ModelConfig: + """ESM model configuration (matches current model config structure)""" + + name: str = "facebook/esm1v_t33_650M_UR90S_1" + revision: str = "main" + device: str = MISSING # Provided by YAML interpolation ${hardware.device} + + +@dataclass +class ClassifierConfig: + """Classifier head configuration (matches current classifier config)""" + + type: str = "logistic_regression" + C: float = 1.0 + penalty: str = "l2" + solver: str = "lbfgs" + max_iter: int = 1000 + random_state: int = ( + MISSING # Provided by YAML interpolation ${training.random_state} + ) + class_weight: str | None = None + cv_folds: int = 10 + stratify: bool = True + + +@dataclass +class DataConfig: + """Dataset configuration (ALL fields used by loaders.py + trainer.py)""" + + # REQUIRED by loaders.py + source: str = "local" + train_file: str = MISSING # Required + test_file: str = MISSING # Required + sequence_column: str = "sequence" + label_column: str = "label" + + # REQUIRED by trainer.py + embeddings_cache_dir: str = "./experiments/cache" + + # Optional fields + dataset_name: str = "boughter_vh" + max_sequence_length: int = 1024 + save_embeddings: bool = True + + # Fragment metadata (testing only) + train_fragment: str = "VH" + test_fragment: str = "VH" + test_assay: str = "ELISA" + test_threshold: float = 0.5 + + +@dataclass +class TrainingConfig: + """Training hyperparameters (ALL fields used by trainer.py)""" + + # Cross-validation + n_splits: int = 10 + random_state: int = 42 + stratify: bool = True + + # Evaluation metrics + metrics: list[str] = field( + default_factory=lambda: ["accuracy", "precision", "recall", "f1", "roc_auc"] + ) + + # Model saving + save_model: bool = True + model_name: str = "boughter_vh_esm1v_logreg" + 
model_save_dir: str = "./experiments/checkpoints" + + # Logging (Hydra-aware: relative to Hydra output dir, or logs/ in legacy mode) + log_level: str = "INFO" + log_file: str = "training.log" # Routes to logs/ dir in legacy mode, Hydra output dir in Hydra mode + + # Performance optimization + batch_size: int = 8 + num_workers: int = 4 + + +@dataclass +class HardwareConfig: + """Hardware settings""" + + device: str = "mps" + gpu_memory_fraction: float = 0.8 + clear_cache_frequency: int = 100 + + +@dataclass +class ExperimentConfig: + """Experiment metadata""" + + name: str = "novo_replication" + description: str = "Train ESM-1v VH-based LogisticReg on Boughter, test on Jain" + tags: list[str] = field(default_factory=lambda: ["baseline", "esm1v", "logreg"]) + + +@dataclass +class Config: + """Root configuration (complete schema matching current trainer.py)""" + + model: ModelConfig = field(default_factory=ModelConfig) + classifier: ClassifierConfig = field(default_factory=ClassifierConfig) + data: DataConfig = field(default_factory=DataConfig) + training: TrainingConfig = field(default_factory=TrainingConfig) + hardware: HardwareConfig = field(default_factory=HardwareConfig) + experiment: ExperimentConfig = field(default_factory=ExperimentConfig) + + +# ConfigStore registrations REMOVED to fix CLI override bug +# +# Root cause: Registering structured configs with the same names as YAML files +# causes Hydra to prefer ConfigStore over YAML when using package-based config +# loading (which the console script does). This breaks config group overrides. +# +# Known issue: Hydra structured configs strictly validate keys. +# Overrides adding new keys require proper schema definition or +key syntax with strict mode disabled.# See: https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching +# +# The dataclasses above are kept for type hints and validation in code, but are +# no longer registered with ConfigStore. 
This allows YAML files to be the single +# source of truth for configuration. +# +# cs = ConfigStore.instance() +# cs.store(name="config", node=Config) +# cs.store(group="model", name="esm1v", node=ModelConfig) +# cs.store(group="classifier", name="logreg", node=ClassifierConfig) +# cs.store(group="data", name="boughter_jain", node=DataConfig) diff --git a/src/antibody_training_esm/conf/data/boughter_jain.yaml b/src/antibody_training_esm/conf/data/boughter_jain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e270b3741342c776cc7fcb2d52ed91e5fd270057 --- /dev/null +++ b/src/antibody_training_esm/conf/data/boughter_jain.yaml @@ -0,0 +1,23 @@ +# Data source (matches current loaders.py requirements) +source: local +dataset_name: boughter_vh + +# File paths +train_file: data/train/boughter/canonical/VH_only_boughter_training.csv +test_file: data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv + +# Data format options (required by loaders.py) +# Jain canonical parity file uses 'vh_sequence'; align config to avoid column errors +sequence_column: sequence +label_column: label +max_sequence_length: 1024 + +# Embedding caching (required by trainer.py) +save_embeddings: true +embeddings_cache_dir: ./experiments/cache + +# Fragment metadata (for testing only) +train_fragment: VH +test_fragment: VH +test_assay: ELISA +test_threshold: 0.5 diff --git a/src/antibody_training_esm/conf/hardware/default.yaml b/src/antibody_training_esm/conf/hardware/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c4a63cc7d20bdd8e765693ef6bac72a4753fec9f --- /dev/null +++ b/src/antibody_training_esm/conf/hardware/default.yaml @@ -0,0 +1,5 @@ +# Hardware configuration +# Default to MPS for macOS performance (training/testing); Gradio app handles stability fallback +device: mps +gpu_memory_fraction: 0.8 +clear_cache_frequency: 100 diff --git a/src/antibody_training_esm/conf/hydra/default.yaml b/src/antibody_training_esm/conf/hydra/default.yaml 
new file mode 100644 index 0000000000000000000000000000000000000000..4d0dba6fe2eed084a36558df34808d78df76c229 --- /dev/null +++ b/src/antibody_training_esm/conf/hydra/default.yaml @@ -0,0 +1,10 @@ +# Hydra output directory management +run: + dir: experiments/runs/${experiment.name}/${now:%Y-%m-%d_%H-%M-%S} + +sweep: + dir: experiments/runs/sweeps/${experiment.name} + subdir: ${hydra.job.num} + +job: + chdir: false # Don't change working directory diff --git a/src/antibody_training_esm/conf/model/esm1v.yaml b/src/antibody_training_esm/conf/model/esm1v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef0e9c99b45b4d46df929c1cac48063af8d3374e --- /dev/null +++ b/src/antibody_training_esm/conf/model/esm1v.yaml @@ -0,0 +1,4 @@ +name: facebook/esm1v_t33_650M_UR90S_1 +revision: main +# Default to CPU for stability on macOS; override with hardware.device or CLI if desired +device: ${hardware.device} diff --git a/src/antibody_training_esm/conf/model/esm2_650m.yaml b/src/antibody_training_esm/conf/model/esm2_650m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16e5384fffa62c3527746aad57b38ac056b4d5cd --- /dev/null +++ b/src/antibody_training_esm/conf/model/esm2_650m.yaml @@ -0,0 +1,3 @@ +name: facebook/esm2_t33_650M_UR50D +revision: main +device: ${hardware.device} diff --git a/src/antibody_training_esm/conf/predict.yaml b/src/antibody_training_esm/conf/predict.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38788b93a91c2b60038d5be65f72ed4262c26578 --- /dev/null +++ b/src/antibody_training_esm/conf/predict.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - /model: esm1v + - /classifier: logreg + - /hardware: default + - _self_ + +input_file: null +output_file: "predictions.csv" +sequence_column: "sequence" +assay_type: null # Options: "PSR", "ELISA", or null +threshold: 0.5 # Ignored if assay_type is set + +gradio: + server_name: "0.0.0.0" + server_port: 7860 + share: false + queue: + 
concurrency_limit: 2 # Based on 8GB VRAM (3GB per ESM-1v inference) + max_size: 10 # Prevents unbounded queue growth + log_level: INFO + +hydra: + job: + chdir: False diff --git a/src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml b/src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea9b839c19848f5ebc3ea4ffe06e8752b60e72c1 --- /dev/null +++ b/src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml @@ -0,0 +1,7 @@ +model_paths: [experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl] +data_paths: [data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv] +sequence_column: vh_sequence +label_column: label +output_dir: experiments/benchmarks +device: cpu +batch_size: 8 diff --git a/src/antibody_training_esm/core/__init__.py b/src/antibody_training_esm/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a168454dbb7ac5d84681ce5b307bb79f9ab434 --- /dev/null +++ b/src/antibody_training_esm/core/__init__.py @@ -0,0 +1,19 @@ +""" +Core ML Module + +Professional ML components for antibody classification: +- ESM embedding extraction +- Binary classification +- Training pipelines +- Model serialization (pickle + NPZ+JSON) +""" + +from antibody_training_esm.core.classifier import BinaryClassifier +from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor +from antibody_training_esm.core.trainer import load_model_from_npz + +__all__ = [ + "BinaryClassifier", + "ESMEmbeddingExtractor", + "load_model_from_npz", +] diff --git a/src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80730656ea830eb4a8e6db0e240fbe3c64710b34 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc differ diff --git 
a/src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72806c6b9bd94096bc60f8410d5a2b16bf1bd11c Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4263dd884bcbfcb132885dc886528fccc8fbcee7 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/classifier_strategy.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/classifier_strategy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..631a786eb587e26b5ad6bf5554eaaa92787a35c2 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/classifier_strategy.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/config.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d4f737f00398e8e6254081eb91cf173958d6a80 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/config.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/directory_utils.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/directory_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1dea107c0fd3604f79be8bd6e7e42ac9229e150 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/directory_utils.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/embeddings.cpython-312.pyc 
b/src/antibody_training_esm/core/__pycache__/embeddings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfdf4caa6dba8408b92c780aa2c29567da873526 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/embeddings.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/prediction.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/prediction.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45989b8c55435858b8b5867785decf1b8a16bc7d Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/prediction.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/regressor.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/regressor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52e43ffbe6d18e69c28aa440d92584fc24a72a65 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/regressor.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/__pycache__/trainer.cpython-312.pyc b/src/antibody_training_esm/core/__pycache__/trainer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ee1bd503279e03118fbfb451272700c9669b88 Binary files /dev/null and b/src/antibody_training_esm/core/__pycache__/trainer.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/classifier.py b/src/antibody_training_esm/core/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..1d815a21407c64acc768c3b50869a34dc712bd50 --- /dev/null +++ b/src/antibody_training_esm/core/classifier.py @@ -0,0 +1,350 @@ +""" +Binary Classifier Module + +Professional binary classifier for antibody sequences using ESM-1V embeddings. +Includes sklearn compatibility, assay-specific thresholds, and model serialization. 
+""" + +import logging +from typing import Any + +import numpy as np + +from antibody_training_esm.core.classifier_factory import create_classifier +from antibody_training_esm.core.classifier_strategy import ClassifierStrategy +from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE +from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor + +logger = logging.getLogger(__name__) + + +class BinaryClassifier: + """Binary classifier for protein sequences using ESM-1V embeddings""" + + # sklearn 1.7+ requires explicit estimator type for cross_val_score + # This tells sklearn's validation logic that we're a classifier, not a regressor + _estimator_type = "classifier" + + # Assay-specific thresholds (Novo Nordisk methodology) + ASSAY_THRESHOLDS = { + "ELISA": 0.5, # Training data type (Boughter, Jain) + "PSR": 0.5495, # PSR assay type (Shehata, Harvey) - EXACT Novo parity + } + + def __init__(self, params: dict[str, Any] | None = None, **kwargs: Any): + """ + Initialize the binary classifier + + Args: + params: Dictionary containing classifier parameters (legacy API) + **kwargs: Individual parameters (for sklearn compatibility) + + Notes: + Supports both dict-based (legacy) and kwargs-based (sklearn) initialization + """ + # Support both dict-based (legacy) and kwargs-based (sklearn) initialization + if params is None: + params = kwargs + + # Validate required parameters (universal across all strategies) + # Note: max_iter is LogReg-specific, removed from required params + REQUIRED_PARAMS = ["random_state", "model_name", "device"] + missing = [p for p in REQUIRED_PARAMS if p not in params] + if missing: + raise ValueError( + f"Missing required parameters: {missing}. 
" + f"BinaryClassifier requires: {REQUIRED_PARAMS}" + ) + + random_state = params["random_state"] + batch_size = params.get( + "batch_size", DEFAULT_BATCH_SIZE + ) # Default if not provided + revision = params.get("revision", "main") # HF model revision (default: "main") + + self.embedding_extractor = ESMEmbeddingExtractor( + params["model_name"], params["device"], batch_size, revision=revision + ) + + # Use factory to create classifier strategy (supports LogReg, XGBoost, etc.) + self.classifier: ClassifierStrategy = create_classifier(params) + + logger.info( + "Classifier initialized: type=%s, params=%s", + params.get("type", "logistic_regression"), + self.classifier.get_params(), + ) + + # Store hyperparameters for recreation and sklearn compatibility + self.random_state = random_state + self.is_fitted = False + self.device = self.embedding_extractor.device + self.model_name = params["model_name"] + self.batch_size = batch_size + self.revision = revision # Store HF model revision for reproducibility + + # Store all params for sklearn compatibility + self._params = params + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """ + Get parameters for sklearn compatibility (required for cross_val_score) + + Args: + deep: If True, return parameters for sub-estimators + + Returns: + Dictionary of parameters (embedding params + classifier params) + """ + # Merge embedding extractor params + classifier strategy params + params = { + "random_state": self.random_state, + "model_name": self.model_name, + "device": self.device, + "batch_size": self.batch_size, + "revision": self.revision, + } + # Add classifier-specific params from strategy + params.update(self.classifier.get_params(deep=deep)) + return params + + def set_params(self, **params: Any) -> "BinaryClassifier": + """ + Set parameters for sklearn compatibility (required for cross_val_score) + + Args: + **params: Parameters to set + + Returns: + self + + Notes: + This method updates parameters without 
destroying fitted state. + If model_name or device changes, the embedding extractor is recreated. + If classifier type changes, the classifier strategy is recreated. + """ + # Update internal params dict + self._params.update(params) + + # Track if we need to recreate components + needs_extractor_reload = False + needs_classifier_reload = False + + # Check which components need reloading + embedding_params = {"model_name", "device", "batch_size", "revision"} + if any(key in params for key in embedding_params): + needs_extractor_reload = True + # Update instance attributes + self.model_name = self._params.get("model_name", self.model_name) + self.device = self._params.get("device", self.device) + self.batch_size = self._params.get("batch_size", self.batch_size) + self.revision = self._params.get("revision", self.revision) + + if "type" in params: + needs_classifier_reload = True + + # Update random_state (used by both components) + if "random_state" in params: + self.random_state = params["random_state"] + + # Recreate embedding extractor if needed + if needs_extractor_reload: + logger.info( + f"Recreating embedding extractor: model_name={self.model_name}, " + f"device={self.device}, batch_size={self.batch_size}" + ) + self.embedding_extractor = ESMEmbeddingExtractor( + self.model_name, self.device, self.batch_size, revision=self.revision + ) + + # Recreate classifier strategy if type changed + if needs_classifier_reload: + logger.info(f"Recreating classifier: type={params.get('type')}") + self.classifier = create_classifier(self._params) + self.is_fitted = False # New classifier is unfitted + else: + # Update existing classifier params (e.g., C, penalty, solver) + classifier_params = { + k: v + for k, v in params.items() + if k not in embedding_params and k not in {"random_state", "type"} + } + if classifier_params: + # For LogReg and other sklearn estimators, update attributes directly + for key, value in classifier_params.items(): + if hasattr(self.classifier, 
key): + setattr(self.classifier, key, value) + # Also update underlying sklearn classifier + if hasattr(self.classifier, "classifier") and hasattr( + self.classifier.classifier, key + ): + setattr(self.classifier.classifier, key, value) + + return self + + def fit(self, X: np.ndarray, y: np.ndarray) -> None: + """ + Fit the classifier to the data + + Args: + X: Array of ESM-1V embeddings + y: Array of labels + """ + # Fit the classifier directly on embeddings (no scaling per Novo methodology) + self.classifier.fit(X, y) + self.is_fitted = True + + # sklearn 1.7+ requires classes_ attribute for cross_val_score compatibility + self.classes_ = self.classifier.classes_ + + logger.info(f"Classifier fitted on {len(X)} samples") + + def predict( + self, X: np.ndarray, threshold: float = 0.5, assay_type: str | None = None + ) -> np.ndarray: + """ + Predict labels for the data with optional assay-specific thresholds + + Args: + X: Array of ESM-1V embeddings + threshold: Decision threshold for classification (default: 0.5) + Ignored if assay_type is specified + assay_type: Type of assay for dataset-specific thresholds. Options: + - 'ELISA': Use threshold=0.5 (for Jain, Boughter datasets) + - 'PSR': Use threshold=0.5495 (for Shehata, Harvey datasets) + - None: Use the threshold parameter + + Returns: + Predicted labels + + Raises: + ValueError: If classifier is not fitted or assay_type is unknown + + Notes: + The model was trained on ELISA data (Boughter dataset). Different assay types + measure different "spectrums" of non-specificity (Sakhnini et al. 2025, Section 2.7). + Use assay_type='PSR' for PSR-based datasets to get calibrated predictions. + """ + if not self.is_fitted: + raise ValueError("Classifier must be fitted before making predictions") + + # Determine which threshold to use + if assay_type is not None: + if assay_type not in self.ASSAY_THRESHOLDS: + raise ValueError( + f"Unknown assay_type '{assay_type}'. 
Must be one of: {list(self.ASSAY_THRESHOLDS.keys())}" + ) + threshold = self.ASSAY_THRESHOLDS[assay_type] + + # Get probabilities and apply threshold + probabilities = self.classifier.predict_proba(X) + predictions: np.ndarray = (probabilities[:, 1] > threshold).astype(int) + + return predictions + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """ + Predict class probabilities for the data + + Args: + X: Array of ESM-1V embeddings + + Returns: + Predicted probabilities + + Raises: + ValueError: If classifier is not fitted + """ + if not self.is_fitted: + raise ValueError("Classifier must be fitted before making predictions") + + result: np.ndarray = self.classifier.predict_proba(X) + return result + + def score(self, X: np.ndarray, y: np.ndarray) -> float: + """ + Return the mean accuracy on the given test data and labels + + Args: + X: Array of ESM-1V embeddings + y: Array of true labels + + Returns: + Mean accuracy + + Raises: + ValueError: If classifier is not fitted + """ + if not self.is_fitted: + raise ValueError("Classifier must be fitted before scoring") + + score: float = self.classifier.score(X, y) + return score + + # ======================================================================== + # Backward Compatibility Properties (delegate to strategy) + # ======================================================================== + + @property + def C(self) -> float: + """Regularization parameter (LogReg only, for backward compatibility)""" + return getattr(self.classifier, "C", 1.0) + + @property + def penalty(self) -> str: + """Regularization type (LogReg only, for backward compatibility)""" + return getattr(self.classifier, "penalty", "l2") + + @property + def solver(self) -> str: + """Optimization algorithm (LogReg only, for backward compatibility)""" + return getattr(self.classifier, "solver", "lbfgs") + + @property + def class_weight(self) -> Any: + """Class weights (LogReg only, for backward compatibility)""" + return 
getattr(self.classifier, "class_weight", None) + + @property + def max_iter(self) -> int: + """Maximum iterations (LogReg only, for backward compatibility)""" + return getattr(self.classifier, "max_iter", 1000) + + def __getstate__(self) -> dict[str, Any]: + """Custom pickle method - don't save the ESM model""" + state = self.__dict__.copy() + # Remove the embedding_extractor (it will be recreated on load) + state.pop("embedding_extractor", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + """Custom unpickle method - recreate ESM model with correct config""" + self.__dict__.update(state) + + # Check for missing attributes from old model versions + warnings_issued = [] + if not hasattr(self, "batch_size"): + warnings_issued.append(f"batch_size (using default: {DEFAULT_BATCH_SIZE})") + if not hasattr(self, "revision"): + warnings_issued.append("revision (using default: 'main')") + + if warnings_issued: + import warnings + + warnings.warn( + f"Loading old model missing attributes: {', '.join(warnings_issued)}. " + "Predictions may differ from original model. Consider retraining with current version.", + UserWarning, + stacklevel=2, + ) + + # Recreate embedding extractor with fixed configuration + batch_size = getattr( + self, "batch_size", DEFAULT_BATCH_SIZE + ) # Default if not stored (backwards compatibility) + revision = getattr( + self, "revision", "main" + ) # Default if not stored (backwards compatibility) + self.embedding_extractor = ESMEmbeddingExtractor( + self.model_name, self.device, batch_size, revision=revision + ) diff --git a/src/antibody_training_esm/core/classifier_factory.py b/src/antibody_training_esm/core/classifier_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..de4342bb2b28534ed8ebae3bc5bd76f7ede84f07 --- /dev/null +++ b/src/antibody_training_esm/core/classifier_factory.py @@ -0,0 +1,140 @@ +""" +Classifier Factory + +Creates classifier strategies based on configuration. 
+Implements Factory Pattern for runtime strategy selection. + +Design Pattern: Factory (Gang of Four) +Purpose: Decouple classifier creation from BinaryClassifier + +Examples: + >>> # Logistic Regression (default) + >>> config = {"C": 1.0, "random_state": 42} + >>> clf = create_classifier(config) + >>> isinstance(clf, LogisticRegressionStrategy) + True + + >>> # XGBoost (Phase 2) + >>> config = {"type": "xgboost", "n_estimators": 100} + >>> clf = create_classifier(config) + >>> isinstance(clf, XGBoostStrategy) + True +""" + +from typing import Any + +from antibody_training_esm.core.classifier_strategy import ClassifierStrategy +from antibody_training_esm.core.strategies.logistic_regression import ( + LogisticRegressionStrategy, +) + + +def create_classifier(config: dict[str, Any]) -> ClassifierStrategy: + """ + Factory function for creating classifier strategies. + + Args: + config: Configuration dictionary with "type" key and hyperparameters + + Returns: + ClassifierStrategy instance (LogReg, XGBoost, etc.) 
+ + Raises: + ValueError: If classifier type is unknown + + Notes: + - Defaults to "logistic_regression" if "type" not specified (backward compat) + - Supported types: "logistic_regression", "xgboost" (Phase 2) + + Examples: + >>> # Logistic Regression (explicit) + >>> config = {"type": "logistic_regression", "C": 1.0} + >>> clf = create_classifier(config) + + >>> # Logistic Regression (implicit default) + >>> config = {"C": 1.0} # No "type" field + >>> clf = create_classifier(config) + + >>> # XGBoost (Phase 2) + >>> config = {"type": "xgboost", "n_estimators": 100} + >>> clf = create_classifier(config) + """ + # Default to logistic_regression for backward compatibility + classifier_type = config.get("type", "logistic_regression") + + if classifier_type == "logistic_regression": + return LogisticRegressionStrategy(config) + elif classifier_type == "xgboost": + # Phase 2: XGBoost implementation + try: + from antibody_training_esm.core.strategies.xgboost_strategy import ( + XGBoostStrategy, + ) + + return XGBoostStrategy(config) + except ImportError as e: + raise ImportError( + "XGBoost classifier requested but xgboost not installed. " + "Install with: pip install xgboost>=2.0.0" + ) from e + else: + raise ValueError( + f"Unknown classifier type: '{classifier_type}'. " + f"Supported types: logistic_regression, xgboost" + ) + + +# Registry pattern for extensibility (future: plugin system) +CLASSIFIER_REGISTRY: dict[str, type[ClassifierStrategy]] = { + "logistic_regression": LogisticRegressionStrategy, +} + + +def register_classifier(name: str, strategy_class: type[ClassifierStrategy]) -> None: + """ + Register a new classifier strategy (for plugins/extensions). + + Args: + name: Classifier type name (e.g., "mlp", "svm") + strategy_class: ClassifierStrategy implementation + + Examples: + >>> # Future: MLP classifier + >>> class MLPStrategy(ClassifierStrategy): + ... 
pass + >>> register_classifier("mlp", MLPStrategy) + """ + CLASSIFIER_REGISTRY[name] = strategy_class + + +def create_classifier_from_registry(config: dict[str, Any]) -> ClassifierStrategy: + """ + Create classifier using registry (extensible version). + + Args: + config: Configuration with "type" key + + Returns: + ClassifierStrategy instance + + Raises: + ValueError: If type not in registry + + Notes: + This function enables plugin-based classifier registration. + Use for third-party classifiers or experimental implementations. + """ + classifier_type = config.get("type", "logistic_regression") + + if classifier_type not in CLASSIFIER_REGISTRY: + raise ValueError( + f"Unknown classifier type: '{classifier_type}'. " + f"Registered types: {list(CLASSIFIER_REGISTRY.keys())}" + ) + + strategy_class = CLASSIFIER_REGISTRY[classifier_type] + # Cast to Any to bypass protocol constructor checks + # (Protocol doesn't strictly define __init__ signature) + from typing import cast + + return cast(ClassifierStrategy, cast(Any, strategy_class)(config)) diff --git a/src/antibody_training_esm/core/classifier_strategy.py b/src/antibody_training_esm/core/classifier_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..12f972c4107840e819f32a4281658567a5fa285f --- /dev/null +++ b/src/antibody_training_esm/core/classifier_strategy.py @@ -0,0 +1,248 @@ +""" +Classifier Strategy Protocol + +Defines the interface for classifier backends (LogReg, XGBoost, MLP, etc.) +Uses Protocol for structural subtyping - sklearn compatibility without inheritance. + +This module implements the Strategy Pattern for classifier algorithms, enabling +runtime swapping of different classifier backends while maintaining a consistent +interface. 
+ +Design Pattern: Strategy (Gang of Four) +Type System: Protocol-based structural subtyping (PEP 544) + +Examples: + >>> # Any class implementing these methods can be used as a classifier + >>> from sklearn.linear_model import LogisticRegression + >>> clf = LogisticRegression() + >>> isinstance(clf, ClassifierStrategy) # True (runtime_checkable) + True + + >>> # Custom strategy implementation + >>> class MyStrategy: + ... def fit(self, X, y): ... + ... def predict(self, X): ... + ... def predict_proba(self, X): ... + ... def get_params(self, deep=True): ... + ... @property + ... def classes_(self): ... + >>> isinstance(MyStrategy(), ClassifierStrategy) # True + True +""" + +from typing import Any, Protocol, runtime_checkable + +import numpy as np + + +@runtime_checkable +class ClassifierStrategy(Protocol): + """ + Protocol for classifier strategies. + + Any class implementing these methods can be used as a classifier backend, + including sklearn estimators (LogisticRegression, XGBClassifier, etc.) + + This protocol defines the minimal interface required by BinaryClassifier. + It follows the sklearn estimator API for maximum compatibility. + + Notes: + - Uses Protocol for duck typing (PEP 544) + - runtime_checkable enables isinstance() checks + - Minimal interface - only what BinaryClassifier needs + - Compatible with sklearn cross_val_score and GridSearchCV + + See Also: + - sklearn.base.BaseEstimator + - sklearn.base.ClassifierMixin + - PEP 544: https://www.python.org/dev/peps/pep-0544/ + """ + + def fit(self, X: np.ndarray, y: np.ndarray) -> None: + """ + Train the classifier on embeddings. + + Args: + X: Embeddings array, shape (n_samples, n_features) + Each row is an ESM embedding vector for one antibody sequence. 
+ y: Labels array, shape (n_samples,) + Binary labels (0 = non-polyreactive, 1 = polyreactive) + + Raises: + ValueError: If X or y have invalid shapes, contain NaN/inf, or + have mismatched dimensions + + Notes: + After calling fit(), the classifier must set the classes_ attribute + to enable sklearn compatibility (required for cross_val_score). + + Examples: + >>> X_train = np.random.rand(100, 1280) # 100 ESM1v embeddings + >>> y_train = np.array([0, 1] * 50) # Binary labels + >>> clf.fit(X_train, y_train) + >>> hasattr(clf, 'classes_') + True + """ + ... + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Predict class labels for embeddings. + + Args: + X: Embeddings array, shape (n_samples, n_features) + Each row is an ESM embedding vector for one antibody sequence. + + Returns: + Predicted labels, shape (n_samples,) + Binary labels (0 = non-polyreactive, 1 = polyreactive) + + Raises: + ValueError: If classifier not fitted (must call fit() first) + ValueError: If X has invalid shape or contains NaN/inf + ValueError: If X.shape[1] != n_features_in_ (mismatched dimensions) + + Examples: + >>> X_test = np.random.rand(20, 1280) + >>> predictions = clf.predict(X_test) + >>> predictions.shape + (20,) + >>> set(predictions) + {0, 1} + """ + ... + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """ + Predict class probabilities for embeddings. + + Args: + X: Embeddings array, shape (n_samples, n_features) + Each row is an ESM embedding vector for one antibody sequence. + + Returns: + Probability array, shape (n_samples, n_classes) + For binary classification: + - [:, 0] = P(non-polyreactive) + - [:, 1] = P(polyreactive) + Probabilities sum to 1.0 for each row. 
+ + Raises: + ValueError: If classifier not fitted (must call fit() first) + ValueError: If X has invalid shape or contains NaN/inf + + Notes: + Used by BinaryClassifier for threshold-based prediction: + - ELISA assay: threshold = 0.5 + - PSR assay: threshold = 0.5495 (Novo Nordisk exact parity) + + Examples: + >>> X_test = np.random.rand(20, 1280) + >>> probs = clf.predict_proba(X_test) + >>> probs.shape + (20, 2) + >>> np.allclose(probs.sum(axis=1), 1.0) + True + """ + ... + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """ + Get classifier hyperparameters (sklearn API). + + Args: + deep: If True, return parameters for nested estimators + (e.g., for sklearn Pipelines or ensemble methods) + + Returns: + Dictionary of hyperparameters + + Notes: + Required for sklearn compatibility: + - cross_val_score needs get_params() for cloning estimators + - GridSearchCV needs get_params() for hyperparameter tuning + - set_params() needs get_params() for validation + + Examples: + >>> params = clf.get_params() + >>> 'C' in params # LogisticRegression + True + >>> 'n_estimators' in params # XGBoost + True + """ + ... + + @property + def classes_(self) -> np.ndarray: + """ + Class labels discovered during fit(). + + Returns: + Array of class labels, shape (n_classes,) + For binary classification: np.array([0, 1]) + + Raises: + AttributeError: If classifier not fitted (must call fit() first) + + Notes: + Required for sklearn compatibility: + - cross_val_score checks classes_ to determine estimator type + - sklearn 1.7+ requires _estimator_type and classes_ for CV + + Examples: + >>> clf.fit(X_train, y_train) + >>> clf.classes_ + array([0, 1]) + """ + ... + + def score(self, X: np.ndarray, y: np.ndarray) -> float: + """ + Return the mean accuracy on the given test data and labels. + + Args: + X: Test samples. + y: True labels for X. + + Returns: + Mean accuracy of self.predict(X) wrt. y. + """ + ... 
+ + def to_dict(self) -> dict[str, Any]: + """ + Serialize classifier hyperparameters to dictionary (for JSON). + + Returns: + Dictionary with all hyperparameters and metadata. + Does NOT include fitted state (arrays) - use to_arrays() for that. + """ + ... + + def to_arrays(self) -> dict[str, np.ndarray]: + """ + Extract fitted state as numpy arrays (for NPZ). + + Returns: + Dictionary of arrays representing the fitted model. + + Raises: + ValueError: If classifier not fitted (must call fit() first) + """ + ... + + @classmethod + def from_dict( + cls, config: dict[str, Any], arrays: dict[str, np.ndarray] | None = None + ) -> "ClassifierStrategy": + """ + Deserialize classifier from dict + arrays. + + Args: + config: Dictionary with hyperparameters (from JSON) + arrays: Dictionary with fitted state (from NPZ), None if unfitted + + Returns: + Reconstructed classifier instance + """ + ... diff --git a/src/antibody_training_esm/core/config.py b/src/antibody_training_esm/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b92e89f257c61550c4cb75a0939d825a4e6962f3 --- /dev/null +++ b/src/antibody_training_esm/core/config.py @@ -0,0 +1,13 @@ +""" +Core configuration defaults. + +Centralizes the magic numbers used across the training pipeline so they can be +updated in one place (or overridden by CLI/config files later). 
+""" + +DEFAULT_BATCH_SIZE = 32 +DEFAULT_MAX_SEQ_LENGTH = 1024 +GPU_CACHE_CLEAR_INTERVAL = 10 # Clear GPU cache every N batches to prevent OOM +ERROR_PREVIEW_LIMIT = 10 # Show first N errors in validation messages +LOG_SEPARATOR_WIDTH = 60 # Width for log separator lines in training output +SEQUENCE_PREVIEW_LENGTH = 50 # Max chars for sequence previews in logs/errors diff --git a/src/antibody_training_esm/core/directory_utils.py b/src/antibody_training_esm/core/directory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7f8101405dd6f3b0b424e275c0ede8ccc62ad2aa --- /dev/null +++ b/src/antibody_training_esm/core/directory_utils.py @@ -0,0 +1,143 @@ +""" +Directory utilities for organizing model artifacts hierarchically + +Implements standardized directory structure: + experiments/checkpoints/{model_shortname}/{classifier_type}/{model_files} + experiments/benchmarks/{model_shortname}/{classifier_type}/{dataset}/{results} +""" + +import re +from pathlib import Path +from typing import Any + + +def extract_model_shortname(model_name: str) -> str: + """ + Extract short model identifier from HuggingFace model name + + Examples: + facebook/esm1v_t33_650M_UR90S_1 -> esm1v + facebook/esm2_t33_650M_UR50D -> esm2_650m + alchemab/antiberta2 -> antiberta + + Args: + model_name: Full HuggingFace model name + + Returns: + Short model identifier + """ + # Handle facebook/esm models + if "esm1v" in model_name.lower(): + return "esm1v" + elif "esm2" in model_name.lower(): + # Extract size info (e.g., 650M, 3B) + match = re.search(r"esm2.*?(\d+[MB])", model_name, re.IGNORECASE) + if match: + size = match.group(1).lower() + return f"esm2_{size}" + return "esm2" + elif "antiberta" in model_name.lower(): + return "antiberta" + elif "protbert" in model_name.lower() or "prot_bert" in model_name.lower(): + return "protbert" + elif "ablang" in model_name.lower(): + return "ablang" + else: + # Fallback: use the last part of the model path + return 
model_name.split("/")[-1].lower() + + +def extract_classifier_shortname(classifier_config: dict[str, Any]) -> str: + """ + Extract short classifier identifier from classifier config + + Examples: + {"type": "logistic_regression", ...} -> logreg + {"type": "xgboost", ...} -> xgboost + {"type": "mlp", ...} -> mlp + + Args: + classifier_config: Classifier configuration dictionary + + Returns: + Short classifier identifier + """ + classifier_type = classifier_config.get("type", "unknown") + + # Map full names to short names + shortname_map = { + "logistic_regression": "logreg", + "xgboost": "xgboost", + "mlp": "mlp", + "svm": "svm", + "random_forest": "rf", + } + + return str(shortname_map.get(classifier_type, classifier_type)) + + +def get_hierarchical_model_dir( + base_dir: str, + model_name: str, + classifier_config: dict[str, Any], +) -> Path: + """ + Generate hierarchical model directory path + + Structure: {base_dir}/{model_shortname}/{classifier_type}/ + + Args: + base_dir: Base models directory (e.g., "./models") + model_name: Full HuggingFace model name + classifier_config: Classifier configuration dictionary + + Returns: + Path to hierarchical model directory + + Examples: + >>> get_hierarchical_model_dir( + ... "./models", + ... "facebook/esm1v_t33_650M_UR90S_1", + ... {"type": "logistic_regression"} + ... 
) + PosixPath('experiments/checkpoints/esm1v/logreg') + """ + model_short = extract_model_shortname(model_name) + classifier_short = extract_classifier_shortname(classifier_config) + + return Path(base_dir) / model_short / classifier_short + + +def get_hierarchical_test_results_dir( + base_dir: str, + model_name: str, + classifier_config: dict[str, Any], + dataset_name: str, +) -> Path: + """ + Generate hierarchical test results directory path + + Structure: {base_dir}/{model_shortname}/{classifier_type}/{dataset}/ + + Args: + base_dir: Base test results directory (e.g., "./experiments/benchmarks") + model_name: Full HuggingFace model name + classifier_config: Classifier configuration dictionary + dataset_name: Dataset name (e.g., "jain", "harvey") + + Returns: + Path to hierarchical test results directory + + Examples: + >>> get_hierarchical_test_results_dir( + ... "./experiments/benchmarks", + ... "facebook/esm1v_t33_650M_UR90S_1", + ... {"type": "logistic_regression"}, + ... "jain" + ... ) + PosixPath('experiments/benchmarks/esm1v/logreg/jain') + """ + model_short = extract_model_shortname(model_name) + classifier_short = extract_classifier_shortname(classifier_config) + + return Path(base_dir) / model_short / classifier_short / dataset_name diff --git a/src/antibody_training_esm/core/embeddings.py b/src/antibody_training_esm/core/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..f842eee5e8e27ea6e94dd38cbef75acad17fd966 --- /dev/null +++ b/src/antibody_training_esm/core/embeddings.py @@ -0,0 +1,300 @@ +""" +ESM Embedding Module + +Professional module for ESM-1V protein sequence embedding extraction. +Handles batch processing, GPU memory management, and validation. 
+""" + +import logging + +import numpy as np +import torch +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer + +from .config import ( + DEFAULT_BATCH_SIZE, + DEFAULT_MAX_SEQ_LENGTH, + ERROR_PREVIEW_LIMIT, + GPU_CACHE_CLEAR_INTERVAL, + SEQUENCE_PREVIEW_LENGTH, +) + +logger = logging.getLogger(__name__) + + +class ESMEmbeddingExtractor: + """Extract ESM-1V embeddings for protein sequences with proper batching and GPU management""" + + def __init__( + self, + model_name: str, + device: str, + batch_size: int = DEFAULT_BATCH_SIZE, + max_length: int = DEFAULT_MAX_SEQ_LENGTH, + revision: str = "main", + ): + """ + Initialize ESM embedding extractor + + Args: + model_name: HuggingFace model identifier (e.g., 'facebook/esm1v_t33_650M_UR90S_1') + device: Device to run model on ('cpu', 'cuda', or 'mps') + batch_size: Number of sequences to process per batch + max_length: Maximum sequence length for tokenizer truncation/padding + revision: HuggingFace model revision (commit SHA or branch name) for reproducibility + """ + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.max_length = max_length + self.revision = revision + + # Load model with output_hidden_states enabled + pinned revision for reproducibility + self.model = AutoModel.from_pretrained( + model_name, + output_hidden_states=True, + revision=revision, # nosec B615 - Pinned to specific version for scientific reproducibility + ) + self.model.to(device) + self.model.eval() + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + revision=revision, # nosec B615 - Pinned to specific version for scientific reproducibility + ) # type: ignore[no-untyped-call] # HuggingFace transformers lacks type stubs + logger.info( + f"ESM model {model_name} (revision={revision}) loaded on {device} " + f"with batch_size={batch_size} and max_length={max_length}" + ) + + def embed_sequence(self, sequence: str) -> np.ndarray: + """ + Extract ESM-1V embedding for a 
single protein sequence + + Args: + sequence: Amino acid sequence string + + Returns: + Embedding vector as numpy array + + Raises: + ValueError: If sequence contains invalid amino acids or is too short + """ + try: + # Validate sequence (20 standard amino acids + X for unknown/ambiguous) + # X is supported by ESM tokenizer for ambiguous residues + valid_aas = set("ACDEFGHIKLMNPQRSTVWYX") + sequence = sequence.upper().strip() + + if not all(aa in valid_aas for aa in sequence): + raise ValueError("Invalid amino acid characters in sequence") + + if len(sequence) < 1: + raise ValueError("Sequence too short") + + # Tokenize the sequence + inputs = self.tokenizer( + sequence, + return_tensors="pt", + truncation=True, + max_length=self.max_length, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get embeddings + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) + embeddings = outputs.hidden_states[-1] # (batch, seq_len, hidden_dim) + + # Use attention mask to properly exclude padding and special tokens + attention_mask = inputs["attention_mask"].unsqueeze( + -1 + ) # (batch, seq_len, 1) + + # Mask out special tokens (first and last) + attention_mask[:, 0, :] = 0 # CLS token + attention_mask[:, -1, :] = 0 # EOS token + + # Masked mean pooling + masked_embeddings = embeddings * attention_mask + sum_embeddings = masked_embeddings.sum(dim=1) # Sum over sequence + sum_mask = attention_mask.sum(dim=1) # Count valid tokens + + # Prevent division by zero (NaN embeddings) + if sum_mask.item() == 0: + raise ValueError( + f"Attention mask is all zeros for sequence (length: {len(sequence)}). " + f"Sequence preview: '{sequence[:SEQUENCE_PREVIEW_LENGTH]}...'. " + "This typically indicates an empty or invalid sequence after masking." 
+ ) + + mean_embeddings = sum_embeddings / sum_mask # Average + + result: np.ndarray = mean_embeddings.squeeze(0).cpu().numpy() + return result + + except Exception as e: + # Add sequence context to error message (truncate for readability) + seq_preview = ( + sequence[:SEQUENCE_PREVIEW_LENGTH] + "..." + if len(sequence) > SEQUENCE_PREVIEW_LENGTH + else sequence + ) + logger.error( + f"Error getting embeddings for sequence (length={len(sequence)}): {seq_preview}" + ) + raise RuntimeError( + f"Failed to extract embedding for sequence of length {len(sequence)}: {seq_preview}" + ) from e + + def extract_batch_embeddings(self, sequences: list[str]) -> np.ndarray: + """ + Extract embeddings for multiple sequences using efficient batching + + Args: + sequences: List of amino acid sequence strings + + Returns: + Array of embeddings with shape (n_sequences, embedding_dim) + """ + embeddings_list = [] + + logger.info( + f"Extracting embeddings for {len(sequences)} sequences with batch_size={self.batch_size}..." 
+ ) + + # Process sequences in batches + num_batches = (len(sequences) + self.batch_size - 1) // self.batch_size + + for batch_idx in tqdm(range(num_batches), desc="Processing batches"): + start_idx = batch_idx * self.batch_size + end_idx = min(start_idx + self.batch_size, len(sequences)) + batch_sequences = sequences[start_idx:end_idx] + + try: + # Validate and clean sequences + valid_aas = set("ACDEFGHIKLMNPQRSTVWYX") + cleaned_sequences: list[str] = [] + invalid_sequences: list[ + tuple[int, str, str] + ] = [] # (index, sequence, reason) + + for seq_idx, seq in enumerate(batch_sequences): + seq = seq.upper().strip() + global_idx = start_idx + seq_idx + + # Check for empty/short sequences + if len(seq) < 1: + invalid_sequences.append( + (global_idx, seq, "empty or too short") + ) + continue + + # Check for invalid amino acids + invalid_chars = [aa for aa in seq if aa not in valid_aas] + if invalid_chars: + reason = f"invalid characters: {set(invalid_chars)}" + invalid_sequences.append( + (global_idx, seq[:SEQUENCE_PREVIEW_LENGTH], reason) + ) + continue + + cleaned_sequences.append(seq) + + # If any sequences are invalid, fail immediately + if invalid_sequences: + error_details = "\n".join( + f" Index {idx}: '{seq}...' ({reason})" + for idx, seq, reason in invalid_sequences[:ERROR_PREVIEW_LIMIT] + ) + total_invalid = len(invalid_sequences) + raise ValueError( + f"Found {total_invalid} invalid sequence(s) in batch {batch_idx}:\n{error_details}" + + ( + f"\n ... 
and {total_invalid - ERROR_PREVIEW_LIMIT} more" + if total_invalid > ERROR_PREVIEW_LIMIT + else "" + ) + ) + + # Tokenize the batch with padding + inputs = self.tokenizer( + cleaned_sequences, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get embeddings for the batch + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) + embeddings = outputs.hidden_states[ + -1 + ] # (batch, seq_len, hidden_dim) + + # Use attention mask to properly exclude padding and special tokens + attention_mask = inputs["attention_mask"].unsqueeze( + -1 + ) # (batch, seq_len, 1) + + # Mask out special tokens (first and last) + attention_mask[:, 0, :] = 0 # CLS token + attention_mask[:, -1, :] = 0 # EOS token + + # Masked mean pooling + masked_embeddings = embeddings * attention_mask + sum_embeddings = masked_embeddings.sum(dim=1) # Sum over sequence + sum_mask = attention_mask.sum(dim=1) # Count valid tokens + + # Prevent division by zero (NaN embeddings) + # Use clamp to avoid zero divisors (min valid tokens = 1) + sum_mask_safe = sum_mask.clamp(min=1e-9) + mean_embeddings = sum_embeddings / sum_mask_safe # Average + + # Check if any sequences had zero mask (would produce near-zero or invalid embeddings) + zero_mask_indices = ( + (sum_mask == 0).any(dim=1).nonzero(as_tuple=True)[0] + ) + if len(zero_mask_indices) > 0: + bad_seqs = [ + cleaned_sequences[i.item()][:SEQUENCE_PREVIEW_LENGTH] + for i in zero_mask_indices[:3] + ] + raise ValueError( + f"Found {len(zero_mask_indices)} sequence(s) with zero attention mask in batch {batch_idx}. " + f"Sample sequences: {bad_seqs}. This indicates empty/invalid sequences after masking." 
+ ) + + # Convert to numpy and add to list + batch_embeddings = mean_embeddings.cpu().numpy() + for emb in batch_embeddings: + embeddings_list.append(emb) + + # Clear GPU cache periodically to prevent OOM + if (batch_idx + 1) % GPU_CACHE_CLEAR_INTERVAL == 0: + self._clear_gpu_cache() + + except Exception as e: + logger.error( + f"CRITICAL: Failed to process batch {batch_idx} (sequences {start_idx}-{end_idx}): {e}" + ) + logger.error( + f"First sequence in failed batch: {batch_sequences[0][:100]}..." + ) + raise RuntimeError( + f"Batch processing failed at batch {batch_idx}. Cannot continue with corrupted embeddings. " + f"Original error: {e}" + ) from e + + # Final cache clear + self._clear_gpu_cache() + return np.array(embeddings_list) + + def _clear_gpu_cache(self) -> None: + """Clear GPU cache for CUDA or MPS devices to prevent memory leaks""" + if str(self.device).startswith("cuda"): + torch.cuda.empty_cache() + elif str(self.device).startswith("mps"): + torch.mps.empty_cache() diff --git a/src/antibody_training_esm/core/prediction.py b/src/antibody_training_esm/core/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7e0b1adb66001205153078d86b1d61a437f954 --- /dev/null +++ b/src/antibody_training_esm/core/prediction.py @@ -0,0 +1,320 @@ +import logging +from pathlib import Path +from typing import cast + +import joblib +import numpy as np +import pandas as pd +import torch +from omegaconf import DictConfig +from sklearn.linear_model import LogisticRegression + +from antibody_training_esm.core.classifier import BinaryClassifier +from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE +from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor +from antibody_training_esm.core.trainer import load_model_from_npz +from antibody_training_esm.models.prediction import ( + AssayType, + PredictionRequest, + PredictionResult, +) + +logger = logging.getLogger(__name__) + + +class Predictor: + """ + A class to handle 
the antibody non-specificity prediction pipeline. + + This class encapsulates the model loading, embedding extraction, and prediction logic. + It follows the principle of 'prepare once, execute many' (though for CLI it's usually once). + """ + + def __init__( + self, + model_name: str, + classifier_path: str, + device: str | None = None, + config_path: str | None = None, + ): + """ + Initialize the Predictor with model configurations. + + Args: + model_name: The name of the ESM model to use (e.g. 'facebook/esm1v_t33_650M_UR90S_1'). + classifier_path: Path to the trained scikit-learn classifier (pickle/joblib file) or NPZ weights. + device: The device to run the model on ('cpu' or 'cuda'). If None, auto-detects. + config_path: Path to the JSON config file (required if classifier_path is .npz). + """ + self.device = self._select_device(device) + self.model_name = model_name + self.classifier_path = classifier_path + self.config_path = config_path + + self._embedder: ESMEmbeddingExtractor | None = None + self._classifier: BinaryClassifier | LogisticRegression | None = None + + @property + def classifier(self) -> BinaryClassifier | LogisticRegression: + """ + Lazy loads the classifier. + + Supports: + 1. Legacy Pickle (.pkl): Loaded via joblib. + 2. Production NPZ (.npz): Loaded via load_model_from_npz using accompanying JSON config. + """ + if self._classifier is None: + path_obj = Path(self.classifier_path) + + if path_obj.suffix == ".npz": + # NPZ loading path + if self.config_path: + json_path = Path(self.config_path) + else: + # Infer JSON path: model.npz -> model_config.json + json_path = path_obj.with_name(f"{path_obj.stem}_config.json") + + if not json_path.exists(): + raise FileNotFoundError( + f"JSON config not found at {json_path}. " + "For .npz models, a corresponding JSON config is required. " + "Specify it explicitly with config_path if the naming convention differs." 
+ ) + + logger.info(f"Loading model from NPZ: {path_obj} (Config: {json_path})") + self._classifier = load_model_from_npz(str(path_obj), str(json_path)) + + else: + # Legacy/Pickle loading path + logger.info(f"Loading model from Pickle: {path_obj}") + self._classifier = joblib.load(self.classifier_path) + + return self._classifier + + @property + def embedder(self) -> ESMEmbeddingExtractor: + """ + Lazy loads the ESM embedding extractor. + + Optimization: + If the loaded classifier is a BinaryClassifier instance (which contains + its own embedding_extractor), we reuse it to avoid double-loading + the 650MB model into GPU/CPU memory. + """ + if self._embedder is None: + # First ensure classifier is loaded (it might have the embedder) + clf = self.classifier + + # Check if it's our BinaryClassifier wrapper that has an embedder + if ( + hasattr(clf, "embedding_extractor") + and clf.embedding_extractor is not None + ): + embedder = clf.embedding_extractor + + # If the persisted embedder device doesn't match requested device, + # recreate it to avoid MPS/CUDA mismatches (common segfault source on macOS). + if self.device and str(embedder.device) != self.device: + batch_size = getattr(embedder, "batch_size", DEFAULT_BATCH_SIZE) + revision = getattr(embedder, "revision", "main") + logger.info( + "Recreating embedder on requested device %s (was %s)", + self.device, + embedder.device, + ) + embedder = ESMEmbeddingExtractor( + model_name=self.model_name, + device=self.device, + batch_size=batch_size, + revision=revision, + ) + + self._embedder = embedder + else: + # Fallback: Create a new one (e.g., if using raw sklearn model) + self._embedder = ESMEmbeddingExtractor( + model_name=self.model_name, + device=self.device, + ) + return self._embedder + + def predict( + self, + sequences: list[str], + threshold: float = 0.5, + assay_type: AssayType | None = None, + ) -> pd.DataFrame: + """ + Predict specificity for a list of sequences. 
+ + Args: + sequences: A list of antibody amino acid sequences. + threshold: Decision threshold (default: 0.5). + assay_type: 'PSR' or 'ELISA' to use calibrated thresholds (overrides threshold). + + Returns: + A DataFrame containing 'prediction' (string) and 'probability' (float) columns. + """ + if not sequences: + return pd.DataFrame(columns=["prediction", "probability"]) + + # Generate embeddings + embeddings = self.embedder.extract_batch_embeddings(sequences) + + # Make predictions + # Check if the classifier supports the custom 'predict' signature with assay_type + # (Our BinaryClassifier does, standard sklearn does not) + if ( + hasattr(self.classifier, "predict") + and "assay_type" in self.classifier.predict.__code__.co_varnames + ): + predictions = self.classifier.predict( + embeddings, threshold=threshold, assay_type=assay_type + ) + else: + # Standard sklearn behavior + probabilities = self.classifier.predict_proba(embeddings) + predictions = (probabilities[:, 1] > threshold).astype(int) + + # Get probabilities (universal) + probabilities = self.classifier.predict_proba(embeddings) + + # Ensure probabilities is a numpy array + if isinstance(probabilities, list): + probabilities = np.array(probabilities) + + # Format results + results = pd.DataFrame( + { + "prediction": [ + "non-specific" if p == 1 else "specific" for p in predictions + ], + "probability": probabilities[ + :, 1 + ], # Probability of class 1 (non-specific) + } + ) + + return results + + def predict_dataframe( + self, + df: pd.DataFrame, + sequence_col: str = "sequence", + threshold: float = 0.5, + assay_type: AssayType | None = None, + ) -> pd.DataFrame: + """ + Predict specificity for sequences in a DataFrame and append results. + + Args: + df: Input DataFrame. + sequence_col: Name of the column containing sequences. + threshold: Decision threshold. + assay_type: 'PSR' or 'ELISA' (overrides threshold). 
+ + Returns: + A copy of the input DataFrame with 'prediction' and 'probability' columns appended. + """ + if sequence_col not in df.columns: + raise ValueError(f"Input DataFrame must contain a '{sequence_col}' column.") + + sequences = df[sequence_col].tolist() + results = self.predict(sequences, threshold=threshold, assay_type=assay_type) + + output_df = df.copy() + output_df["prediction"] = results["prediction"].values + output_df["probability"] = results["probability"].values + + return output_df + + def predict_single( + self, + sequence: str | PredictionRequest, + threshold: float = 0.5, + assay_type: AssayType | None = None, + ) -> PredictionResult: + """ + Predict single sequence with Pydantic validation. + + Args: + sequence: Raw string OR PredictionRequest model + threshold: Decision threshold (ignored if PredictionRequest passed) + assay_type: Assay type (ignored if PredictionRequest passed) + + Returns: + PredictionResult model + """ + # Normalize input to PredictionRequest + if isinstance(sequence, str): + request = PredictionRequest( + sequence=sequence, + threshold=threshold, + assay_type=assay_type, + ) + else: + request = sequence + + # Extract validated sequence + cleaned_seq = request.sequence + + # Run prediction (existing logic) + results_df = self.predict( + [cleaned_seq], + threshold=request.threshold, + assay_type=request.assay_type, + ) + + # Convert to PredictionResult + return PredictionResult( + sequence=cleaned_seq, + prediction=results_df["prediction"].iloc[0], + probability=float(results_df["probability"].iloc[0]), + threshold=request.threshold, + assay_type=request.assay_type, + ) + + @staticmethod + def _select_device(device: str | None) -> str: + """ + Select the best available device. + + Prioritizes CUDA, then MPS (macOS), then CPU. 
+ """ + if device: + return device + + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + return "cpu" + + +def run_prediction(input_df: pd.DataFrame, cfg: DictConfig) -> pd.DataFrame: + """ + Helper function to run prediction using Hydra config. + + Args: + input_df: DataFrame containing an sequence column. + cfg: The Hydra configuration object. + + Returns: + DataFrame with 'prediction' and 'probability' columns added. + """ + config_path = getattr(cfg.classifier, "config_path", None) + + predictor = Predictor( + model_name=cfg.model.name, + classifier_path=cfg.classifier.path, + config_path=config_path, + ) + + # Extract config parameters with defaults + sequence_col = getattr(cfg, "sequence_column", "sequence") + threshold = getattr(cfg, "threshold", 0.5) + assay_type = cast(AssayType | None, getattr(cfg, "assay_type", None)) + + return predictor.predict_dataframe( + input_df, sequence_col=sequence_col, threshold=threshold, assay_type=assay_type + ) diff --git a/src/antibody_training_esm/core/strategies/__init__.py b/src/antibody_training_esm/core/strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7694620270e75064033f5a0b8c06661c47176cbb --- /dev/null +++ b/src/antibody_training_esm/core/strategies/__init__.py @@ -0,0 +1,33 @@ +""" +Classifier Strategy Implementations + +This package contains concrete implementations of the ClassifierStrategy protocol. 
+ +Available Strategies: + - LogisticRegressionStrategy: Wrapper for sklearn LogisticRegression + - XGBoostStrategy: Wrapper for xgboost XGBClassifier (requires xgboost>=2.0.0) + +Usage: + >>> from antibody_training_esm.core.strategies import LogisticRegressionStrategy + >>> config = {"C": 1.0, "random_state": 42} + >>> clf = LogisticRegressionStrategy(config) + >>> clf.fit(X_train, y_train) + >>> predictions = clf.predict(X_test) +""" + +from antibody_training_esm.core.strategies.logistic_regression import ( + LogisticRegressionStrategy, +) + +__all__ = [ + "LogisticRegressionStrategy", +] + +# XGBoostStrategy will be added in Phase 2 +try: + from antibody_training_esm.core.strategies.xgboost_strategy import XGBoostStrategy + + __all__.append("XGBoostStrategy") +except ImportError: + # xgboost not installed - skip XGBoostStrategy + pass diff --git a/src/antibody_training_esm/core/strategies/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/core/strategies/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a00b3435c090dbc9bb2dd577c9ec7dc0619dcf Binary files /dev/null and b/src/antibody_training_esm/core/strategies/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/strategies/__pycache__/logistic_regression.cpython-312.pyc b/src/antibody_training_esm/core/strategies/__pycache__/logistic_regression.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79a9a62ca5461d85f1cc3875f55b029a98a7f4e0 Binary files /dev/null and b/src/antibody_training_esm/core/strategies/__pycache__/logistic_regression.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/strategies/__pycache__/xgboost_strategy.cpython-312.pyc b/src/antibody_training_esm/core/strategies/__pycache__/xgboost_strategy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a64c2ee46179ce975c8f966457ab2e33bef2b2a Binary files 
class LogisticRegressionStrategy:
    """
    Logistic Regression classifier strategy.

    Wraps sklearn LogisticRegression with the ClassifierStrategy interface
    (fit/predict) and the SerializableClassifier interface (to_dict/to_arrays/
    from_dict), enabling the Strategy Pattern for multiple classifier backends
    (XGBoost, MLP, etc.) and pickle-free production deployment.

    Attributes:
        classifier: sklearn LogisticRegression instance
        C: Inverse regularization strength
        penalty: Regularization type: 'l1', 'l2', 'elasticnet', 'none'
        solver: Optimization algorithm: 'lbfgs', 'liblinear', 'saga', etc.
        max_iter: Maximum iterations for optimization
        random_state: Random seed for reproducibility
        class_weight: Class weights: 'balanced', dict, or None

    Notes:
        - ALL hyperparameters are REQUIRED config keys. The YAML config is the
          single source of truth; there are no Python-side defaults.
        - No scaling applied (matches Novo Nordisk methodology)

    See Also:
        - sklearn.linear_model.LogisticRegression
        - docs/research/novo-parity.md (methodology)
    """

    # Every hyperparameter must be present in config. Validated up front so a
    # misconfigured YAML fails fast with a clear message instead of a bare
    # KeyError('C') from deep inside __init__.
    _REQUIRED_KEYS = (
        "C",
        "penalty",
        "solver",
        "max_iter",
        "random_state",
        "class_weight",
    )

    def __init__(self, config: dict[str, Any]) -> None:
        """
        Initialize LogisticRegression strategy.

        Args:
            config: Configuration dictionary. ALL of the following keys are
                required (no defaults are provided in Python; the YAML config
                is the single source of truth):
                C, penalty, solver, max_iter, random_state, class_weight.

        Raises:
            KeyError: If any required hyperparameter key is missing.

        Examples:
            >>> config = {
            ...     "C": 1.0, "penalty": "l2", "solver": "lbfgs",
            ...     "max_iter": 1000, "random_state": 42, "class_weight": None,
            ... }
            >>> strategy = LogisticRegressionStrategy(config)
            >>> strategy.C
            1.0
        """
        missing = [key for key in self._REQUIRED_KEYS if key not in config]
        if missing:
            raise KeyError(
                f"LogisticRegressionStrategy config missing required keys: {missing}"
            )

        self.C = config["C"]
        self.penalty = config["penalty"]
        self.solver = config["solver"]
        self.max_iter = config["max_iter"]
        self.random_state = config["random_state"]
        self.class_weight = config["class_weight"]

        # Create sklearn LogisticRegression estimator
        self.classifier = LogisticRegression(
            C=self.C,
            penalty=self.penalty,
            solver=self.solver,
            max_iter=self.max_iter,
            random_state=self.random_state,
            class_weight=self.class_weight,
        )

    # ========================================================================
    # ClassifierStrategy Protocol Methods
    # ========================================================================

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit LogisticRegression on embeddings.

        Args:
            X: Embeddings array, shape (n_samples, n_features)
            y: Labels array, shape (n_samples,)

        Notes:
            No scaling is applied (matches Novo Nordisk methodology).
            After fitting, the classes_ attribute is available.
        """
        self.classifier.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels.

        Args:
            X: Embeddings array, shape (n_samples, n_features)

        Returns:
            Predicted labels, shape (n_samples,)
        """
        result: np.ndarray = self.classifier.predict(X)
        return result

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities.

        Args:
            X: Embeddings array, shape (n_samples, n_features)

        Returns:
            Probability array, shape (n_samples, n_classes); rows sum to 1.
        """
        result: np.ndarray = self.classifier.predict_proba(X)
        return result

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """
        Return mean accuracy on test data.

        Args:
            X: Embeddings array, shape (n_samples, n_features)
            y: True labels, shape (n_samples,)

        Returns:
            Mean accuracy score in [0.0, 1.0]
        """
        result: float = self.classifier.score(X, y)
        return result

    def get_params(self, deep: bool = True) -> dict[str, Any]:  # noqa: ARG002
        """
        Get hyperparameters (sklearn API).

        Args:
            deep: If True, return params for nested estimators (unused)

        Returns:
            Dictionary of hyperparameters, including the 'type' metadata key.
        """
        # Same payload as the serialization dict - single source of truth.
        return self.to_dict()

    @property
    def classes_(self) -> np.ndarray:
        """
        Class labels discovered during fit.

        Returns:
            Array of class labels, shape (n_classes,)

        Raises:
            AttributeError: If classifier not fitted
        """
        result: np.ndarray = self.classifier.classes_
        return result

    # ========================================================================
    # SerializableClassifier Protocol Methods
    # ========================================================================

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize hyperparameters to dict (for JSON).

        Returns:
            Dictionary with all hyperparameters plus a 'type' discriminator.
            Does NOT include fitted state (arrays) - use to_arrays() for that.
        """
        return {
            "type": "logistic_regression",
            "C": self.C,
            "penalty": self.penalty,
            "solver": self.solver,
            "max_iter": self.max_iter,
            "random_state": self.random_state,
            "class_weight": self.class_weight,
        }

    def to_arrays(self) -> dict[str, np.ndarray]:
        """
        Extract fitted state as arrays (for NPZ).

        Returns:
            Dictionary of arrays representing the fitted model
            (coef, intercept, classes, n_features_in, n_iter).

        Raises:
            ValueError: If classifier not fitted
        """
        if not hasattr(self.classifier, "coef_"):
            raise ValueError("Classifier must be fitted before serialization")

        return {
            "coef": self.classifier.coef_,
            "intercept": self.classifier.intercept_,
            "classes": self.classifier.classes_,
            "n_features_in": np.array([self.classifier.n_features_in_]),
            "n_iter": self.classifier.n_iter_,
        }

    @classmethod
    def from_dict(
        cls, config: dict[str, Any], arrays: dict[str, np.ndarray] | None = None
    ) -> "LogisticRegressionStrategy":
        """
        Deserialize from dict + arrays.

        Args:
            config: Dictionary with hyperparameters (from JSON). Must contain
                all required keys (a to_dict() payload always does).
            arrays: Dictionary with fitted state (from NPZ), None if unfitted

        Returns:
            Reconstructed LogisticRegressionStrategy instance
        """
        # Create unfitted classifier (extra keys such as 'type' are ignored)
        strategy = cls(config)

        # Restore fitted state directly onto the sklearn estimator attributes
        if arrays is not None:
            strategy.classifier.coef_ = arrays["coef"]
            strategy.classifier.intercept_ = arrays["intercept"]
            strategy.classifier.classes_ = arrays["classes"]
            strategy.classifier.n_features_in_ = int(arrays["n_features_in"][0])
            strategy.classifier.n_iter_ = arrays["n_iter"]

        return strategy
class XGBoostStrategy:
    """
    XGBoost classifier strategy.

    Wraps xgboost.XGBClassifier with the ClassifierStrategy interface
    (fit/predict) and native pickle-free serialization (save_model/load_model
    via XGBoost's .xgb format).

    XGBoost provides gradient boosting trees capable of learning nonlinear
    decision boundaries, which can outperform linear models like
    LogisticRegression on complex antibody polyreactivity patterns.

    Attributes:
        classifier: xgboost.XGBClassifier instance
        n_estimators: Number of boosting rounds
        max_depth: Maximum tree depth
        learning_rate: Boosting learning rate
        subsample: Subsample ratio of training instances
        colsample_bytree: Subsample ratio of features
        reg_alpha: L1 regularization on weights
        reg_lambda: L2 regularization on weights
        random_state: Random seed for reproducibility
        objective: Learning objective (e.g. "binary:logistic")

    Notes:
        - ALL hyperparameters are REQUIRED config keys. The YAML config is the
          single source of truth; there are no Python-side defaults.
        - Uses XGBoost's native .xgb serialization format (no pickle dependency)
        - For production deployment, use save_model() + to_dict() (JSON + .xgb)

    See Also:
        - xgboost.XGBClassifier
        - docs/developer-guide/xgboost.md
    """

    # Every hyperparameter must be present in config. Validated up front so a
    # misconfigured YAML fails fast with a clear message instead of a bare
    # KeyError from deep inside __init__.
    _REQUIRED_KEYS = (
        "n_estimators",
        "max_depth",
        "learning_rate",
        "subsample",
        "colsample_bytree",
        "reg_alpha",
        "reg_lambda",
        "random_state",
        "objective",
    )

    def __init__(self, config: dict[str, Any]) -> None:
        """
        Initialize XGBoost strategy.

        Args:
            config: Configuration dictionary. ALL of the following keys are
                required (no defaults are provided in Python; the YAML config
                is the single source of truth):
                n_estimators, max_depth, learning_rate, subsample,
                colsample_bytree, reg_alpha, reg_lambda, random_state,
                objective.

        Raises:
            KeyError: If any required hyperparameter key is missing.

        Examples:
            >>> config = {
            ...     "n_estimators": 100, "max_depth": 6, "learning_rate": 0.3,
            ...     "subsample": 1.0, "colsample_bytree": 1.0,
            ...     "reg_alpha": 0.0, "reg_lambda": 1.0,
            ...     "random_state": 42, "objective": "binary:logistic",
            ... }
            >>> strategy = XGBoostStrategy(config)
            >>> strategy.n_estimators
            100
        """
        missing = [key for key in self._REQUIRED_KEYS if key not in config]
        if missing:
            raise KeyError(
                f"XGBoostStrategy config missing required keys: {missing}"
            )

        self.n_estimators = config["n_estimators"]
        self.max_depth = config["max_depth"]
        self.learning_rate = config["learning_rate"]
        self.subsample = config["subsample"]
        self.colsample_bytree = config["colsample_bytree"]
        self.reg_alpha = config["reg_alpha"]
        self.reg_lambda = config["reg_lambda"]
        self.random_state = config["random_state"]
        self.objective = config["objective"]

        # Create XGBClassifier estimator
        self.classifier = xgb.XGBClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            learning_rate=self.learning_rate,
            subsample=self.subsample,
            colsample_bytree=self.colsample_bytree,
            reg_alpha=self.reg_alpha,
            reg_lambda=self.reg_lambda,
            random_state=self.random_state,
            objective=self.objective,
        )

    # ========================================================================
    # ClassifierStrategy Protocol Methods
    # ========================================================================

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit XGBoost on embeddings.

        Args:
            X: Embeddings array, shape (n_samples, n_features)
            y: Labels array, shape (n_samples,)

        Notes:
            No scaling is applied (matches Novo Nordisk methodology).
            After fitting, the classes_ attribute is available.
        """
        self.classifier.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels.

        Args:
            X: Embeddings array, shape (n_samples, n_features)

        Returns:
            Predicted labels, shape (n_samples,)
        """
        result: np.ndarray = self.classifier.predict(X)
        return result

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities.

        Args:
            X: Embeddings array, shape (n_samples, n_features)

        Returns:
            Probability array, shape (n_samples, n_classes); rows sum to 1.
        """
        result: np.ndarray = self.classifier.predict_proba(X)
        return result

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """
        Return mean accuracy on test data.

        Args:
            X: Embeddings array, shape (n_samples, n_features)
            y: True labels, shape (n_samples,)

        Returns:
            Mean accuracy score in [0.0, 1.0]
        """
        result: float = self.classifier.score(X, y)
        return result

    def get_params(self, deep: bool = True) -> dict[str, Any]:  # noqa: ARG002
        """
        Get hyperparameters (sklearn API).

        Args:
            deep: If True, return params for nested estimators (unused)

        Returns:
            Dictionary of hyperparameters, including the 'type' metadata key.
        """
        # Same payload as the serialization dict - single source of truth.
        return self.to_dict()

    @property
    def classes_(self) -> np.ndarray:
        """
        Class labels discovered during fit.

        Returns:
            Array of class labels, shape (n_classes,)

        Raises:
            AttributeError: If classifier not fitted
        """
        result: np.ndarray = self.classifier.classes_
        return result

    # ========================================================================
    # Native Serialization Methods (XGBoost .xgb format)
    # ========================================================================

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize hyperparameters to dict (for JSON).

        Returns:
            Dictionary with all hyperparameters plus a 'type' discriminator.
            Does NOT include fitted state - use save_model() for that.
        """
        return {
            "type": "xgboost",
            "n_estimators": self.n_estimators,
            "max_depth": self.max_depth,
            "learning_rate": self.learning_rate,
            "subsample": self.subsample,
            "colsample_bytree": self.colsample_bytree,
            "reg_alpha": self.reg_alpha,
            "reg_lambda": self.reg_lambda,
            "random_state": self.random_state,
            "objective": self.objective,
        }

    def to_arrays(self) -> dict[str, np.ndarray]:
        """
        Extract fitted state as numpy arrays (for NPZ).

        Returns:
            Empty dict (XGBoost uses native .xgb format, not NPZ arrays).
            Required for ClassifierStrategy protocol compliance.
        """
        return {}

    def save_model(self, path: str) -> None:
        """
        Save fitted model to XGBoost native .xgb format.

        Args:
            path: File path for .xgb model file

        Raises:
            ValueError: If classifier not fitted
        """
        # NOTE(review): fitted-check relies on the private _Booster attribute
        # being absent until fit() - confirm against the pinned xgboost version.
        if not hasattr(self.classifier, "_Booster"):
            raise ValueError("Classifier must be fitted before saving")

        self.classifier.save_model(path)

    @classmethod
    def load_model(cls, path: str, config: dict[str, Any]) -> XGBoostStrategy:
        """
        Load fitted model from XGBoost native .xgb format.

        Args:
            path: File path to .xgb model file
            config: Configuration dictionary with hyperparameters
                (must contain all required keys; a to_dict() payload does)

        Returns:
            XGBoostStrategy with loaded model
        """
        # Create unfitted classifier, then restore the booster from disk
        strategy = cls(config)
        strategy.classifier.load_model(path)
        return strategy

    @classmethod
    def from_dict(
        cls, config: dict[str, Any], _arrays: dict[str, np.ndarray] | None = None
    ) -> ClassifierStrategy:
        """
        Create XGBoostStrategy from configuration dictionary.

        Args:
            config: Configuration dictionary with hyperparameters
            _arrays: Ignored for XGBoost (uses .xgb file instead)

        Returns:
            Unfitted XGBoostStrategy instance (load_model needed for fitted state)
        """
        return cls(config)
+""" + +import logging +from pathlib import Path +from typing import Any + +import hydra +import numpy as np +from omegaconf import DictConfig + +from antibody_training_esm.core.classifier import BinaryClassifier +from antibody_training_esm.core.config import LOG_SEPARATOR_WIDTH + +# Import and re-export from submodules +from antibody_training_esm.core.training.cache import ( + get_or_create_embeddings, + validate_embeddings, +) +from antibody_training_esm.core.training.metrics import ( + evaluate_model, + perform_cross_validation, + save_cv_results, +) +from antibody_training_esm.core.training.serialization import ( + load_config, + load_model_from_npz, + save_model, +) +from antibody_training_esm.data.loaders import load_data + +__all__ = [ + "validate_config", + "setup_logging", + "load_config", + "validate_embeddings", + "get_or_create_embeddings", + "evaluate_model", + "perform_cross_validation", + "save_cv_results", + "save_model", + "load_model_from_npz", + "train_pipeline", + "main", +] + + +from antibody_training_esm.models.config import TrainingPipelineConfig + + +def validate_config(config: dict[str, Any] | DictConfig) -> TrainingPipelineConfig: + """ + Validate config with Pydantic models. + + Args: + config: Raw dict or Hydra DictConfig + + Returns: + Validated TrainingPipelineConfig + + Raises: + ValidationError: If config is invalid + """ + if isinstance(config, DictConfig): + return TrainingPipelineConfig.from_hydra(config) + result: TrainingPipelineConfig = TrainingPipelineConfig.model_validate(config) + return result + + +def setup_logging(config: TrainingPipelineConfig) -> logging.Logger: + """ + Setup logging from Pydantic config. 
    Args:
        config: Validated TrainingPipelineConfig

    Returns:
        Configured logger
    """
    from hydra.core.hydra_config import HydraConfig

    log_level = getattr(logging, config.training.log_level.upper())
    log_file = config.training.log_file

    # Hydra-aware path resolution (same as before)
    try:
        hydra_cfg = HydraConfig.get()
        output_dir = Path(hydra_cfg.runtime.output_dir)
        log_path = output_dir / log_file
        log_path.parent.mkdir(parents=True, exist_ok=True)
    except (ValueError, AttributeError):
        # Not running under Hydra: resolve the log file relative to CWD.
        log_path = Path(log_file)
        if not log_path.is_absolute():
            log_path = Path.cwd() / log_file
        log_path.parent.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
        # force=True replaces any handlers configured earlier (e.g. by Hydra)
        force=True,
    )

    return logging.getLogger(__name__)


def train_pipeline(cfg: DictConfig) -> dict[str, Any]:
    """
    Core training pipeline with Pydantic validation.

    Validates the config, embeds the training sequences (with caching),
    runs cross-validation, trains the final classifier, evaluates it on
    the training set, and optionally saves the model.

    Args:
        cfg: Hydra DictConfig for the full training run.

    Returns:
        Dict with keys 'train_metrics', 'cv_metrics', 'config', 'model_paths'.

    Raises:
        Exception: Any failure is logged and re-raised to the caller.
    """
    # Validate config (now returns Pydantic model)
    config = validate_config(cfg)

    # Setup logging (accepts Pydantic model now)
    logger = setup_logging(config)

    logger.info("Starting antibody classification training")
    logger.info(f"Experiment: {config.experiment.name}")

    try:
        X_train, y_train = load_data(config)

        logger.info(f"Loaded {len(X_train)} training samples")

        # Initialize classifier
        classifier_params = {
            "model_name": config.model.name,
            "device": config.model.device,
            "batch_size": config.model.batch_size,
            "revision": config.model.revision,
            # Classifier strategy params
            "strategy": config.classifier.strategy,
            "C": config.classifier.C,
            "penalty": config.classifier.penalty,
            "solver": config.classifier.solver,
            "class_weight": config.classifier.class_weight,
            "max_iter": config.classifier.max_iter,
            "random_state": config.classifier.random_state,
            "n_estimators": config.classifier.n_estimators,
            "max_depth": config.classifier.max_depth,
            "learning_rate": config.classifier.learning_rate,
        }

        classifier = BinaryClassifier(classifier_params)

        # Get embeddings (cache_dir from config)
        cache_dir = config.data.embeddings_cache_dir
        X_train_embedded = get_or_create_embeddings(
            X_train, classifier.embedding_extractor, cache_dir, "train", logger
        )

        # Convert labels to numpy array
        y_train_array: np.ndarray = np.array(y_train)

        # Perform CV (returns CVResults Pydantic model)
        cv_results = perform_cross_validation(
            X_train_embedded,
            y_train_array,
            config,  # Passing Pydantic model
            logger,
        )

        # Save CV results
        try:
            from hydra.core.hydra_config import HydraConfig

            hydra_cfg = HydraConfig.get()
            cv_output_dir = Path(hydra_cfg.runtime.output_dir)
            experiment_name = config.experiment.name
            logger.info(f"Saving CV results to Hydra output dir: {cv_output_dir}")
        except (ValueError, AttributeError, ImportError):
            cv_output_dir = config.training.model_save_dir
            experiment_name = config.experiment.name
            logger.info(f"Running without Hydra, saving CV results to {cv_output_dir}")

        save_cv_results(cv_results, cv_output_dir, experiment_name, logger)

        # Train final model
        classifier.fit(X_train_embedded, y_train_array)

        # Evaluate (returns EvaluationMetrics Pydantic model)
        train_results = evaluate_model(
            classifier,
            X_train_embedded,
            y_train_array,
            "Training",
            list(config.training.metrics),  # Cast to list for type safety
            logger,
        )

        # Save model
        if config.training.save_model:
            # save_model expects config dict or object.
            # We'll pass Pydantic config.
            # Attach metrics to config for metadata saving
            # NOTE(review): assumes TrainingPipelineConfig permits setting an
            # extra attribute - confirm the Pydantic model_config allows it.
            config.train_metrics = train_results.model_dump(
                mode="json", exclude_none=True
            )
            model_paths = save_model(classifier, config, logger)
        else:
            model_paths = {}

        return {
            "train_metrics": train_results,
            "cv_metrics": cv_results,
            "config": config.model_dump(),  # Convert back to dict for serialization
            "model_paths": model_paths,
        }

    except Exception as e:
        # Broad catch is deliberate: log the failure, then re-raise so the
        # CLI exits nonzero with the original traceback.
        logger.error(f"Training failed: {e}")
        raise


@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """
    Hydra entry point for CLI - DO NOT call directly in tests

    This is the CLI entry point decorated with @hydra.main. It:
    - Automatically parses command-line overrides
    - Creates Hydra output directories
    - Saves composed config to .hydra/config.yaml
    - Delegates to train_pipeline() for core logic

    Usage:
        # Default config
        python -m antibody_training_esm.core.trainer

        # With overrides
        python -m antibody_training_esm.core.trainer model.batch_size=16

        # Multi-run sweep
        python -m antibody_training_esm.core.trainer --multirun model=esm1v,esm2

    Note:
        Tests should call train_pipeline() directly, not this function.
        This function is only for CLI usage with sys.argv parsing.
+ """ + logger = logging.getLogger(__name__) + logger.info(f"Starting training with Hydra (experiment: {cfg.experiment.name})") + + try: + # Call core training pipeline + results = train_pipeline(cfg) + + # Log final results (access Pydantic fields) + train_metrics = results["train_metrics"] + cv_metrics = results["cv_metrics"] + + logger.info("=" * LOG_SEPARATOR_WIDTH) + logger.info("TRAINING COMPLETE") + logger.info("=" * LOG_SEPARATOR_WIDTH) + logger.info(f"Train Accuracy: {train_metrics.accuracy:.4f}") + logger.info( + f"CV Accuracy: {cv_metrics.cv_accuracy['mean']:.4f} " + f"(+/- {cv_metrics.cv_accuracy['std'] * 2:.4f})" + ) + + if results.get("model_paths"): + logger.info(f"Model saved to: {results['model_paths']['pickle']}") + + logger.info("=" * LOG_SEPARATOR_WIDTH) + + except Exception as e: + logger.error(f"Training failed: {str(e)}") + raise + + +if __name__ == "__main__": + # Use Hydra main entry point (parses sys.argv automatically) + main() diff --git a/src/antibody_training_esm/core/training/__init__.py b/src/antibody_training_esm/core/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03fd8c2b696c14c1ef182ffef9552da90be4767f --- /dev/null +++ b/src/antibody_training_esm/core/training/__init__.py @@ -0,0 +1,3 @@ +from .cache import get_or_create_embeddings, validate_embeddings +from .metrics import evaluate_model, perform_cross_validation, save_cv_results +from .serialization import load_config, load_model_from_npz, save_model diff --git a/src/antibody_training_esm/core/training/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/core/training/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..084025abec75831a2c1ca4ddfc84d7e137029612 Binary files /dev/null and b/src/antibody_training_esm/core/training/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/training/__pycache__/cache.cpython-312.pyc 
b/src/antibody_training_esm/core/training/__pycache__/cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0bf9358f6736e8d02da6bd3079046b01522dbaa Binary files /dev/null and b/src/antibody_training_esm/core/training/__pycache__/cache.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/training/__pycache__/evaluation.cpython-312.pyc b/src/antibody_training_esm/core/training/__pycache__/evaluation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72eac9083a77e5e7eb33f62728f8eb2beafad1ff Binary files /dev/null and b/src/antibody_training_esm/core/training/__pycache__/evaluation.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/training/__pycache__/metrics.cpython-312.pyc b/src/antibody_training_esm/core/training/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41b86c5da30292dde602eb9f300bebd04f00dc0f Binary files /dev/null and b/src/antibody_training_esm/core/training/__pycache__/metrics.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/training/__pycache__/serialization.cpython-312.pyc b/src/antibody_training_esm/core/training/__pycache__/serialization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9503134e7f0617663717af581346f2cb5e7b12a4 Binary files /dev/null and b/src/antibody_training_esm/core/training/__pycache__/serialization.cpython-312.pyc differ diff --git a/src/antibody_training_esm/core/training/cache.py b/src/antibody_training_esm/core/training/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..03d8fcf49b7781c7c1b6effcc8b3b22ba709bf89 --- /dev/null +++ b/src/antibody_training_esm/core/training/cache.py @@ -0,0 +1,201 @@ +""" +Embedding cache management. + +Handles loading, saving, and validating ESM embeddings to disk to avoid +redundant computation. 
+""" + +import hashlib +import logging +import os +import pickle # nosec B403 +from pathlib import Path +from typing import Any + +import numpy as np + +from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor + + +def validate_embeddings( + embeddings: np.ndarray, + num_sequences: int, + logger: logging.Logger, + source: str = "cache", +) -> None: + """ + Validate embeddings are not corrupted. + + Args: + embeddings: Embedding array to validate + num_sequences: Expected number of sequences + logger: Logger instance + source: Where embeddings came from (for error messages) + + Raises: + ValueError: If embeddings are invalid (wrong shape, NaN, all zeros) + """ + # Check shape + if embeddings.shape[0] != num_sequences: + raise ValueError( + f"Embeddings from {source} have wrong shape: expected {num_sequences} sequences, " + f"got {embeddings.shape[0]}" + ) + + if len(embeddings.shape) != 2: + raise ValueError( + f"Embeddings from {source} must be 2D array, got shape {embeddings.shape}" + ) + + # Check for NaN values + if np.isnan(embeddings).any(): + nan_count = np.isnan(embeddings).sum() + raise ValueError( + f"Embeddings from {source} contain {nan_count} NaN values. " + "This indicates corrupted embeddings - cannot train on invalid data." + ) + + # Check for all-zero rows (corrupted/failed embeddings) + zero_rows = np.all(embeddings == 0, axis=1) + if zero_rows.any(): + zero_count = zero_rows.sum() + raise ValueError( + f"Embeddings from {source} contain {zero_count} all-zero rows. " + "This indicates corrupted embeddings from failed batch processing. " + "Delete the cache file and recompute." 
+ ) + + logger.debug( + f"Embeddings validation passed: shape={embeddings.shape}, no NaN, no zero rows" + ) + + +def get_or_create_embeddings( + sequences: list[str], + embedding_extractor: ESMEmbeddingExtractor, + cache_path: str | Path, + dataset_name: str, + logger: logging.Logger, +) -> np.ndarray: + """ + Get embeddings from cache or create them + + Args: + sequences: List of protein sequences + embedding_extractor: ESM embedding extractor + cache_path: Directory for caching embeddings + dataset_name: Name of dataset (for cache filename) + logger: Logger instance + + Returns: + Array of embeddings + + Raises: + ValueError: If cached or computed embeddings are invalid + """ + # Ensure cache_path is string for os.path.join/os.makedirs compatibility + # (os.path supports Path in 3.6+, but for safety/consistency with type hint) + cache_path_str = str(cache_path) + + # Create a hash that includes model metadata to prevent cache collisions + # between different backbones (ESM-1v, ESM2, AntiBERTa, etc.) + sequences_str = "|".join(sequences) + cache_key_components = ( + f"{embedding_extractor.model_name}|" + f"{embedding_extractor.revision}|" + f"{embedding_extractor.max_length}|" + f"{sequences_str}" + ) + # Use SHA-256 (non-cryptographic usage) to satisfy security scanners and + # prevent weak-hash findings while keeping deterministic cache keys. + sequences_hash = hashlib.sha256(cache_key_components.encode()).hexdigest()[:12] + cache_file = os.path.join( + cache_path_str, f"{dataset_name}_{sequences_hash}_embeddings.pkl" + ) + + if os.path.exists(cache_file): + logger.info(f"Loading cached embeddings from {cache_file}") + with open(cache_file, "rb") as f: + cached_data_raw = pickle.load(f) # nosec B301 - Hash-validated local cache + + # Validate loaded data type and structure + if not isinstance(cached_data_raw, dict): + logger.warning( + f"Invalid cache file format (expected dict, got {type(cached_data_raw).__name__}). " + "Recomputing embeddings..." 
+ ) + elif ( + "embeddings" not in cached_data_raw + or "sequences_hash" not in cached_data_raw + ): + missing_keys = {"embeddings", "sequences_hash"} - set( + cached_data_raw.keys() + ) + logger.warning( + f"Corrupt cache file (missing keys: {missing_keys}). " + "Recomputing embeddings..." + ) + else: + cached_data: dict[str, Any] = cached_data_raw + + # Verify the cached sequences and model metadata match exactly + # This prevents ESM2 from reusing ESM-1v embeddings, etc. + model_metadata_matches = ( + cached_data.get("model_name") == embedding_extractor.model_name + and cached_data.get("revision") == embedding_extractor.revision + and cached_data.get("max_length") == embedding_extractor.max_length + ) + + if ( + len(cached_data["embeddings"]) == len(sequences) + and cached_data["sequences_hash"] == sequences_hash + and model_metadata_matches + ): + logger.info( + f"Using cached embeddings for {len(sequences)} sequences " + f"(model: {embedding_extractor.model_name}, hash: {sequences_hash})" + ) + embeddings_result: np.ndarray = cached_data["embeddings"] + + # Validate cached embeddings before using them + validate_embeddings( + embeddings_result, len(sequences), logger, source="cache" + ) + + return embeddings_result + elif not model_metadata_matches: + logger.warning( + f"Cached embeddings model mismatch " + f"(cached: {cached_data.get('model_name')}, " + f"current: {embedding_extractor.model_name}). " + "Recomputing..." 
+ ) + else: + logger.warning("Cached embeddings hash mismatch, recomputing...") + + logger.info(f"Computing embeddings for {len(sequences)} sequences...") + embeddings = embedding_extractor.extract_batch_embeddings(sequences) + + # Validate newly computed embeddings before caching + validate_embeddings(embeddings, len(sequences), logger, source="computed") + + # Cache the embeddings with metadata for verification + # Include model metadata to prevent cache collisions between different backbones + os.makedirs(cache_path_str, exist_ok=True) + cache_data = { + "embeddings": embeddings, + "sequences_hash": sequences_hash, + "num_sequences": len(sequences), + "dataset_name": dataset_name, + "model_name": embedding_extractor.model_name, + "revision": embedding_extractor.revision, + "max_length": embedding_extractor.max_length, + } + with open(cache_file, "wb") as f: + pickle.dump(cache_data, f) + logger.info( + f"Cached embeddings to {cache_file} " + f"(model: {embedding_extractor.model_name}, hash: {sequences_hash})" + ) + + return embeddings diff --git a/src/antibody_training_esm/core/training/metrics.py b/src/antibody_training_esm/core/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff4810a9be43665c70cffbd5586726c88f0cbb2 --- /dev/null +++ b/src/antibody_training_esm/core/training/metrics.py @@ -0,0 +1,214 @@ +""" +Evaluation metrics and cross-validation logic. + +Computes accuracy, F1, ROC-AUC, and other classification metrics. +Handles logging and result storage. 
"""
Evaluation metrics and cross-validation logic.

Computes accuracy, F1, ROC-AUC, and other classification metrics.
Handles logging and result storage.
"""

import logging
from collections.abc import Sequence
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import yaml
from sklearn.metrics import classification_report
from sklearn.model_selection import KFold, StratifiedKFold, cross_validate

from antibody_training_esm.core.classifier import BinaryClassifier
from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE
from antibody_training_esm.models.artifact import CVResults, EvaluationMetrics

if TYPE_CHECKING:
    from antibody_training_esm.models.config import TrainingPipelineConfig


def evaluate_model(
    classifier: BinaryClassifier,
    X: np.ndarray,
    y: np.ndarray,
    dataset_name: str,
    _metrics: Sequence[str] | set[str],
    logger: logging.Logger,
) -> EvaluationMetrics:
    """
    Compute, log, and return evaluation metrics for a fitted classifier.

    Args:
        classifier: Trained classifier
        X: Embeddings array
        y: Labels array
        dataset_name: Name of dataset being evaluated
        _metrics: Requested metric names (unused; all standard metrics are computed)
        logger: Logger instance

    Returns:
        EvaluationMetrics Pydantic model
    """
    logger.info(f"Evaluating model on {dataset_name} set")

    predictions = classifier.predict(X)
    probabilities = classifier.predict_proba(X)

    # The Pydantic factory computes accuracy/precision/recall/F1/ROC-AUC in one place.
    result = EvaluationMetrics.from_sklearn_metrics(
        y_true=y,
        y_pred=predictions,
        y_proba=probabilities,
        dataset_name=dataset_name,
    )

    logger.info(f"{dataset_name} Results:")
    logger.info(f" Accuracy: {result.accuracy:.4f}")
    if result.precision is not None:
        logger.info(f" Precision: {result.precision:.4f}")
    if result.recall is not None:
        logger.info(f" Recall: {result.recall:.4f}")
    if result.f1 is not None:
        logger.info(f" F1: {result.f1:.4f}")
    if result.roc_auc is not None:
        logger.info(f" ROC-AUC: {result.roc_auc:.4f}")

    # Full sklearn classification report gives per-class detail.
    logger.info(f"\n{dataset_name} Classification Report:")
    logger.info(f"\n{classification_report(y, predictions)}")

    return result


def perform_cross_validation(
    X: np.ndarray,
    y: np.ndarray,
    config: "TrainingPipelineConfig | dict[str, Any]",
    logger: logging.Logger,
) -> CVResults:
    """
    Run k-fold cross-validation and summarize the per-fold metrics.

    Args:
        X: Embeddings array
        y: Labels array
        config: Configuration (Pydantic object or legacy dict)
        logger: Logger instance

    Returns:
        CVResults Pydantic model
    """
    from antibody_training_esm.models.config import TrainingPipelineConfig

    # Resolve CV settings and the classifier parameter set from either config flavor.
    if isinstance(config, TrainingPipelineConfig):
        n_folds = config.training.n_splits
        seed = config.training.random_state
        use_stratified = config.training.stratify
        estimator_params = config.classifier.model_dump()
        estimator_params["model_name"] = config.model.name
        estimator_params["device"] = config.model.device
        estimator_params["batch_size"] = config.model.batch_size
    else:
        training_conf = config.get("training", {})
        classifier_conf = config.get("classifier", {})
        model_cfg = config.get("model", {})

        n_folds = training_conf.get("n_splits", classifier_conf.get("cv_folds", 10))
        use_stratified = training_conf.get("stratify", True)
        seed = training_conf.get(
            "random_state", classifier_conf.get("random_state", 42)
        )
        estimator_params = classifier_conf.copy()
        estimator_params["model_name"] = model_cfg.get("name", "")
        estimator_params["device"] = model_cfg.get("device", "cpu")
        estimator_params["batch_size"] = training_conf.get(
            "batch_size", model_cfg.get("batch_size", DEFAULT_BATCH_SIZE)
        )

    logger.info(f"Performing {n_folds}-fold cross-validation")

    # Stratified splitting preserves the label ratio within each fold.
    splitter_cls = StratifiedKFold if use_stratified else KFold
    splitter = splitter_cls(n_splits=n_folds, shuffle=True, random_state=seed)

    # Fresh classifier instance dedicated to CV (the fitted one stays untouched).
    cv_classifier = BinaryClassifier(estimator_params)

    scoring = {
        "accuracy": "accuracy",
        "f1": "f1",
        "precision": "precision",
        "recall": "recall",
        "roc_auc": "roc_auc",
    }

    # One cross_validate call computes every metric per fold (cheaper than
    # repeated cross_val_score calls).
    fold_scores = cross_validate(
        cv_classifier, X, y, cv=splitter, scoring=scoring, return_train_score=False
    )

    cv_results = CVResults.from_sklearn_cv_results(fold_scores, n_splits=n_folds)

    logger.info("Cross-validation Results:")
    logger.info(
        f" Accuracy: {cv_results.cv_accuracy['mean']:.4f} (+/- {cv_results.cv_accuracy['std'] * 2:.4f})"
    )
    if cv_results.cv_f1:
        logger.info(
            f" F1: {cv_results.cv_f1['mean']:.4f} (+/- {cv_results.cv_f1['std'] * 2:.4f})"
        )
    if cv_results.cv_roc_auc:
        logger.info(
            f" ROC-AUC: {cv_results.cv_roc_auc['mean']:.4f} (+/- {cv_results.cv_roc_auc['std'] * 2:.4f})"
        )

    return cv_results


def save_cv_results(
    cv_results: CVResults,
    output_dir: Path,
    experiment_name: str,
    logger: logging.Logger,
) -> None:
    """
    Persist cross-validation results as a structured YAML file.

    Args:
        cv_results: CVResults Pydantic model
        output_dir: Directory to save CV results file
        experiment_name: Name of the experiment
        logger: Logger instance
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / "cv_results.yaml"

    # model_dump(mode="json") gives plain JSON-safe types for YAML emission.
    payload = {
        "experiment": experiment_name,
        "timestamp": datetime.now().isoformat(),
        "cv_metrics": cv_results.model_dump(mode="json"),
    }
    with open(target, "w") as f:
        yaml.dump(payload, f, default_flow_style=False)

    logger.info(f"CV results saved to {target}")
"""
Model serialization utilities.

Handles saving/loading models in dual format (pickle for dev, NPZ+JSON for production).
Manages configuration loading and directory structure.
"""

import json
import logging
import pickle  # nosec B403
from typing import TYPE_CHECKING, Any, cast

import numpy as np
import yaml

from antibody_training_esm.core.classifier import BinaryClassifier
from antibody_training_esm.core.directory_utils import get_hierarchical_model_dir
from antibody_training_esm.models.artifact import ModelArtifactMetadata

if TYPE_CHECKING:
    from antibody_training_esm.models.config import TrainingPipelineConfig


def load_config(config_path: str) -> dict[str, Any]:
    """
    Load configuration from YAML file

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Configuration dictionary

    Raises:
        FileNotFoundError: If config file doesn't exist
        ValueError: If YAML is invalid or does not contain a mapping
    """
    try:
        with open(config_path) as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        raise FileNotFoundError(
            f"Config file not found: {config_path}\n"
            "Please create it or specify a valid path with --config"
        ) from None
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in config file {config_path}: {e}") from e

    # safe_load returns None for an empty file and may return scalars/lists for
    # non-mapping documents; fail fast here instead of crashing downstream with
    # an opaque AttributeError/TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Invalid config file {config_path}: expected a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config


def save_model(
    classifier: BinaryClassifier,
    config: "TrainingPipelineConfig | dict[str, Any]",
    logger: logging.Logger,
) -> dict[str, str]:
    """
    Save trained model in dual format (pickle + NPZ+JSON)

    Models are saved in hierarchical directory structure:
    {model_save_dir}/{model_shortname}/{classifier_type}/{model_name}.*

    Example:
        experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl

    Args:
        classifier: Trained classifier
        config: Configuration dictionary or Pydantic model
        logger: Logger instance

    Returns:
        Dictionary with paths to saved files (keys: "pickle", then "xgb" or
        "npz" depending on the wrapped estimator, and "config").
        Empty dict if saving is disabled.
    """
    from antibody_training_esm.models.config import TrainingPipelineConfig

    # Extract config values regardless of type (Pydantic vs legacy dict)
    if isinstance(config, TrainingPipelineConfig):
        if not config.training.save_model:
            return {}
        model_name = config.training.model_name
        base_save_dir = config.training.model_save_dir
        model_shortname = config.model.name
        classifier_config = config.classifier.model_dump()
        train_metrics = getattr(config, "train_metrics", None)
    else:
        if not config["training"]["save_model"]:
            return {}
        model_name = config["training"]["model_name"]
        base_save_dir = config["training"]["model_save_dir"]
        model_shortname = config["model"]["name"]
        classifier_config = config["classifier"]
        train_metrics = config.get("train_metrics")

    # Generate hierarchical directory path
    hierarchical_dir = get_hierarchical_model_dir(
        str(base_save_dir),
        model_shortname,
        classifier_config,
    )
    hierarchical_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using hierarchical model directory: {hierarchical_dir}")

    base_path = hierarchical_dir / model_name

    # Format 1: Pickle checkpoint (research/debugging)
    pickle_path = f"{base_path}.pkl"
    with open(pickle_path, "wb") as f:
        pickle.dump(classifier, f)
    logger.info(f"Saved pickle checkpoint: {pickle_path}")

    saved_paths = {"pickle": str(pickle_path)}

    # Format 2: Strategy-specific production serialization.
    # Duck typing detects which serialization the wrapped estimator supports.
    if hasattr(classifier.classifier, "save_model"):
        # XGBoost native .xgb format (pickle-free)
        xgb_path = f"{base_path}.xgb"
        classifier.classifier.save_model(str(xgb_path))
        logger.info(f"Saved XGBoost native model: {xgb_path}")
        saved_paths["xgb"] = str(xgb_path)
    elif hasattr(classifier.classifier, "to_arrays"):
        # LogReg NPZ format (sklearn arrays)
        npz_path = f"{base_path}.npz"
        arrays = classifier.classifier.to_arrays()
        np.savez(npz_path, **cast(dict[str, Any], arrays))
        logger.info(f"Saved NPZ arrays: {npz_path}")
        saved_paths["npz"] = str(npz_path)
    else:
        # Fallback: legacy LogReg direct attribute access.
        # Cast to Any because the protocol doesn't enforce LogReg attributes.
        inner_clf = cast(Any, classifier.classifier)
        npz_path = f"{base_path}.npz"
        np.savez(
            npz_path,
            coef=inner_clf.coef_,
            intercept=inner_clf.intercept_,
            classes=inner_clf.classes_,
            n_features_in=np.array([inner_clf.n_features_in_]),
            n_iter=inner_clf.n_iter_,
        )
        logger.info(f"Saved NPZ arrays (legacy): {npz_path}")
        saved_paths["npz"] = str(npz_path)

    # Format 3: JSON metadata (Pydantic handles serialization)
    json_path = f"{base_path}_config.json"
    metadata = ModelArtifactMetadata.from_classifier(classifier)

    # Add training metrics if available
    if train_metrics:
        metadata.training_metrics = train_metrics

    with open(json_path, "w") as f:
        # model_dump(mode='json') handles decimal/float serialization
        json.dump(metadata.model_dump(mode="json"), f, indent=2)

    logger.info(f"Saved JSON config: {json_path}")
    saved_paths["config"] = str(json_path)

    logger.info(f"Model saved successfully ({metadata.model_type} format)")
    return saved_paths


def load_model_from_npz(npz_path: str, json_path: str) -> BinaryClassifier:
    """
    Load model from NPZ+JSON format (production deployment)

    Args:
        npz_path: Path to .npz file with arrays
        json_path: Path to .json file with metadata

    Returns:
        Reconstructed BinaryClassifier instance

    Notes:
        This function enables production deployment without pickle files.
        It reconstructs a fully functional BinaryClassifier from NPZ+JSON format.
        Uses strict Pydantic validation for metadata.
    """
    # Load arrays
    arrays = np.load(npz_path)
    coef = arrays["coef"]
    intercept = arrays["intercept"]
    classes = arrays["classes"]
    n_features_in = int(arrays["n_features_in"][0])
    n_iter = arrays["n_iter"]

    # Load metadata (Pydantic validates)
    with open(json_path) as f:
        metadata_dict = json.load(f)

    metadata = ModelArtifactMetadata.model_validate(metadata_dict)

    # Construct BinaryClassifier from metadata (Pydantic handles types)
    params = metadata.to_classifier_params()
    classifier = BinaryClassifier(params)

    # Restore fitted LogisticRegression state.
    # NOTE(review): this writes to classifier.classifier.classifier.* (the
    # sklearn estimator nested inside the strategy object), while save_model's
    # legacy fallback reads coef_ directly from classifier.classifier. The two
    # paths target different nesting levels — confirm against the strategy
    # implementation that save/load round-trips for both layouts.
    inner_clf = cast(Any, classifier.classifier)
    inner_clf.classifier.coef_ = coef
    inner_clf.classifier.intercept_ = intercept
    inner_clf.classifier.classes_ = classes
    inner_clf.classifier.n_features_in_ = n_features_in
    inner_clf.classifier.n_iter_ = n_iter
    classifier.is_fitted = True

    return classifier
0000000000000000000000000000000000000000..7c87f2a6a14ce1bd56930d29ae0eb5929d50741f Binary files /dev/null and b/src/antibody_training_esm/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/data/__pycache__/loaders.cpython-312.pyc b/src/antibody_training_esm/data/__pycache__/loaders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..728228567a00cbe7bf67e3f7d9092c11fcbfe4f4 Binary files /dev/null and b/src/antibody_training_esm/data/__pycache__/loaders.cpython-312.pyc differ diff --git a/src/antibody_training_esm/data/loaders.py b/src/antibody_training_esm/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..5286f62dee7d2df55d0d4e79b6d9806a5e80b6b6 --- /dev/null +++ b/src/antibody_training_esm/data/loaders.py @@ -0,0 +1,238 @@ +""" +Data Loading Module + +Professional data loading utilities for antibody sequence datasets. +Supports Hugging Face datasets, local CSV files, and preprocessing pipelines. +""" + +import logging +import pickle # nosec B403 - Used only for local trusted data (preprocessed datasets) +from collections.abc import Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Any, Protocol, cast + +import numpy as np +import pandas as pd + +# HuggingFace datasets library lacks complete type stubs +# See: https://github.com/huggingface/datasets/issues/3426 +from datasets import load_dataset # type: ignore[attr-defined] + +if TYPE_CHECKING: + from antibody_training_esm.models.config import TrainingPipelineConfig + +logger = logging.getLogger(__name__) +type Label = int | float | bool | str + + +class EmbeddingExtractor(Protocol): + """Protocol for embedding extractors""" + + def extract_batch_embeddings(self, sequences: Sequence[str]) -> np.ndarray: + """Extract embeddings for a batch of sequences""" + ... 
# lgtm[py/ineffectual-statement] + + def embed_sequence(self, sequence: str) -> np.ndarray: + """Extract embedding for a single sequence""" + ... # lgtm[py/ineffectual-statement] + + +def preprocess_raw_data( + X: Sequence[str], + y: Sequence[Label], + embedding_extractor: EmbeddingExtractor, +) -> tuple[np.ndarray, np.ndarray]: + """ + Embed sequences using ESM model + + Args: + X: List of input protein sequences (strings) + y: List or array of labels + embedding_extractor: Instance with 'embed_sequence' or 'extract_batch_embeddings' method + + Returns: + X_embedded: Embedded input sequences + y: Labels as numpy array + + Notes: + No StandardScaler used - matches Novo Nordisk methodology + """ + logger.info(f"Embedding {len(X)} sequences...") + + # Try to use batch embedding if available (more efficient) + if hasattr(embedding_extractor, "extract_batch_embeddings"): + X_embedded = embedding_extractor.extract_batch_embeddings(X) + else: + X_embedded = np.array([embedding_extractor.embed_sequence(seq) for seq in X]) + + return X_embedded, np.array(y) + + +def store_preprocessed_data( + X: Sequence[str] | None = None, + y: Sequence[Label] | None = None, + X_embedded: np.ndarray | None = None, + filename: str | None = None, +) -> None: + """ + Store preprocessed data to pickle file + + Args: + X: Raw sequences (optional) + y: Labels (optional) + X_embedded: Embedded data (optional) + filename: Output file path (required) + + Raises: + ValueError: If filename is not provided + """ + if filename is None: + raise ValueError("filename is required") + + data: dict[str, Sequence[str] | Sequence[Label] | np.ndarray] = {} + if X_embedded is not None: + data["X_embedded"] = X_embedded + if X is not None: + data["X"] = X + if y is not None: + data["y"] = y + + with open(filename, "wb") as f: + pickle.dump(data, f) + + +def load_preprocessed_data( + filename: str, +) -> dict[str, list[str] | list[Label] | np.ndarray]: + """ + Load preprocessed data from pickle file + + Args: + 
filename: Path to pickle file + + Returns: + Dictionary with keys: 'X', 'y', and/or 'X_embedded' + """ + with open(filename, "rb") as f: + data = cast(dict[str, list[str] | list[Label] | np.ndarray], pickle.load(f)) # nosec B301 - Loading our own preprocessed dataset from local file + return data + + +def load_hf_dataset( + dataset_name: str, + split: str, + text_column: str, + label_column: str, + revision: str = "main", +) -> tuple[list[str], list[Label]]: + """ + Load dataset from Hugging Face datasets library + + Args: + dataset_name: Name of the dataset to load + split: Which split to load (e.g., 'train', 'test', 'validation') + text_column: Name of the column containing the sequences + label_column: Name of the column containing the labels + revision: HuggingFace dataset revision (commit SHA or branch name) for reproducibility + + Returns: + X: List of input sequences + y: List of labels + """ + dataset = load_dataset( + dataset_name, + split=split, + revision=revision, # nosec B615 - Pinned to specific version for scientific reproducibility + ) + X = list(dataset[text_column]) + y = cast(list[Label], list(dataset[label_column])) + + return X, y + + +def load_local_data( + file_path: str | Path, text_column: str, label_column: str +) -> tuple[list[str], list[Label]]: + """ + Load training data from local CSV file + + Args: + file_path: Path to the local data file (CSV format) + text_column: Name of the column containing the sequences + label_column: Name of the column containing the labels + + Returns: + X: List of input sequences + y: List of labels + + Raises: + ValueError: If required columns are missing from CSV + """ + train_df = pd.read_csv(file_path, comment="#") # Handle comment lines in CSV + + # Validate required columns exist + available_columns = list(train_df.columns) + if text_column not in train_df.columns: + raise ValueError( + f"Sequence column '{text_column}' not found in {file_path}. 
" + f"Available columns: {available_columns}" + ) + if label_column not in train_df.columns: + raise ValueError( + f"Label column '{label_column}' not found in {file_path}. " + f"Available columns: {available_columns}" + ) + + X_train = train_df[text_column].tolist() + y_train = cast(list[Label], train_df[label_column].tolist()) + + return X_train, y_train + + +def load_data( + config: "dict[str, Any] | TrainingPipelineConfig", +) -> tuple[list[str], list[Label]]: + """ + Load training data from Pydantic config or legacy dict. + + Args: + config: Validated TrainingPipelineConfig or dict (legacy) + + Returns: + (sequences, labels) + """ + from antibody_training_esm.models.config import TrainingPipelineConfig + + if isinstance(config, TrainingPipelineConfig): + train_file = config.data.train_file + return load_local_data(train_file, "sequence", "label") + + # Handle Legacy Dict Config + data_config = config.get("data", {}) + + # Simplified logic: If 'train_file' exists, use local loader directly + # This matches the new schema which only supports local files via 'train_file' + if "train_file" in data_config: + # Legacy configs might specify columns, or we default to 'sequence'/'label' + # The new schema assumes 'sequence' and 'label' columns are present + # For backward compatibility, we check if keys exist, else default + seq_col = data_config.get("sequence_column", "sequence") + label_col = data_config.get("label_column", "label") + return load_local_data(data_config["train_file"], seq_col, label_col) + + # Keep existing logic for "source" key if present (strictly legacy/HF support) + if data_config.get("source") == "hf": + return load_hf_dataset( + dataset_name=data_config["dataset_name"], + split=data_config["train_split"], + text_column=data_config["sequence_column"], + label_column=data_config["label_column"], + ) + elif data_config.get("source") == "local": + return load_local_data( + data_config["train_file"], + text_column=data_config["sequence_column"], + 
label_column=data_config["label_column"], + ) + else: + raise ValueError(f"Unknown data source configuration: {data_config}") diff --git a/src/antibody_training_esm/datasets/README.md b/src/antibody_training_esm/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ddd3445c7c403f5a626e4342af6d48ee47502cf9 --- /dev/null +++ b/src/antibody_training_esm/datasets/README.md @@ -0,0 +1,204 @@ +# Dataset Loaders + +**IMPORTANT**: These classes are for **LOADING** preprocessed data, NOT for running preprocessing pipelines. + +## Architecture Overview + +This codebase maintains a clear separation between: + +1. **Preprocessing Scripts** (`preprocessing/`) - Creates fragment files +2. **Dataset Loaders** (`src/antibody_training_esm/datasets/`) - Loads fragment files for training + +``` +preprocessing/ src/antibody_training_esm/datasets/ +ā”œā”€ā”€ jain/ ā”œā”€ā”€ jain.py +│ └── step2_preprocess_p5e_s2.py │ (loads preprocessed data) +│ (CREATES fragments) │ +ā”œā”€ā”€ harvey/ ā”œā”€ā”€ harvey.py +│ └── step2_extract_fragments.py │ (loads preprocessed data) +│ (CREATES fragments) │ +ā”œā”€ā”€ shehata/ ā”œā”€ā”€ shehata.py +│ └── step2_extract_fragments.py │ (loads preprocessed data) +│ (CREATES fragments) │ +└── boughter/ └── boughter.py + └── stage2_stage3_annotation_qc.py (loads preprocessed data) + (CREATES fragments) +``` + +## Single Source of Truth (SSOT) + +**Preprocessing scripts** in `preprocessing/` are the **canonical source of truth** for: +- Data transformation logic +- Filtering rules +- Quality control +- Fragment generation +- ANARCI annotation + +**Dataset loaders** in `src/antibody_training_esm/datasets/` are **abstractions** for: +- Loading preprocessed fragment files +- Providing consistent API across datasets +- Validation and statistics +- Integration with training pipelines + +## Usage + +### Preprocessing (ONE-TIME, run scripts) + +```bash +# Jain dataset +python preprocessing/jain/step2_preprocess_p5e_s2.py + +# Harvey 
dataset +python preprocessing/harvey/step2_extract_fragments.py + +# Shehata dataset +python preprocessing/shehata/step2_extract_fragments.py + +# Boughter dataset +python preprocessing/boughter/stage2_stage3_annotation_qc.py +``` + +These scripts create fragment CSV files in: +- `data/test//fragments/` +- `data/train//fragments/` + +### Loading (TRAINING, use dataset classes) + +```python +# Option 1: Use dataset class +from antibody_training_esm.datasets import JainDataset + +dataset = JainDataset() +df = dataset.load_data(stage="parity") # Loads preprocessed 86-antibody set +print(f"Loaded {len(df)} sequences") + +# Option 2: Use convenience function +from antibody_training_esm.datasets import load_jain_data + +df = load_jain_data(stage="parity") +print(f"Loaded {len(df)} sequences") +``` + +## Available Dataset Loaders + +### JainDataset +- **Source**: `data/test/jain/processed/` +- **Fragments**: `data/test/jain/fragments/` +- **Preprocessing**: `preprocessing/jain/step2_preprocess_p5e_s2.py` +- **Characteristics**: + - 137 → 116 → 86 antibodies (Novo parity) + - PSR/AC-SINS filtering + - 16 fragment types (VH + VL) + +### HarveyDataset +- **Source**: `data/test/harvey/raw/` +- **Fragments**: `data/test/harvey/fragments/` +- **Preprocessing**: `preprocessing/harvey/step2_extract_fragments.py` +- **Characteristics**: + - 141,474 nanobody sequences (VHH only) + - 6 fragment types (nanobody-specific) + - IMGT position extraction + +### ShehataDataset +- **Source**: `data/test/shehata/raw/shehata-mmc2.xlsx` +- **Fragments**: `data/test/shehata/fragments/` +- **Preprocessing**: `preprocessing/shehata/step2_extract_fragments.py` +- **Characteristics**: + - 398 HIV antibodies + - PSR threshold-based labeling (98.24th percentile) + - 16 fragment types (VH + VL) + +### BoughterDataset +- **Source**: `data/train/boughter/raw/` (DNA FASTA files) +- **Fragments**: `data/train/boughter/annotated/` +- **Preprocessing**: 
`preprocessing/boughter/stage2_stage3_annotation_qc.py` +- **Characteristics**: + - Mouse antibodies (6 subsets) + - DNA translation required + - Novo flagging strategy (0/1-3/4+) + - 16 fragment types (VH + VL) + +## Design Principles + +### Single Responsibility (SRP) +- Each dataset loader handles ONE dataset +- Preprocessing scripts handle ONE pipeline +- No overlap in responsibilities + +### Open/Closed Principle (OCP) +- New datasets can be added by extending `AntibodyDataset` +- Existing code doesn't need modification +- Preprocessing scripts remain independent + +### Dependency Inversion (DIP) +- Training code depends on `AntibodyDataset` abstraction +- Not on specific dataset implementations +- Preprocessing scripts are independent + +## Why This Architecture? + +### Benefits + +1. **Clear Separation**: Preprocessing ≠ Loading +2. **Single Source of Truth**: Preprocessing scripts are authoritative +3. **Bit-for-Bit Parity**: Can validate new vs old outputs +4. **No Rewrites**: Preprocessing logic stays in scripts +5. **Professional Structure**: Industry-standard ML organization + +### What This Architecture Prevents + +āŒ **DON'T DO THIS**: +```python +# WRONG: Trying to preprocess from dataset class +dataset = JainDataset() +# There is NO process() method - it was intentionally removed! +# Dataset loaders are for LOADING preprocessed data only. 
+```
+
+āœ… **DO THIS INSTEAD**:
+
+**Step 1: Run preprocessing script ONCE to CREATE fragment files:**
+```bash
+python preprocessing/jain/step2_preprocess_p5e_s2.py
+```
+
+**Step 2: LOAD the preprocessed data in your training code:**
+```python
+from antibody_training_esm.datasets import load_jain_data
+
+df = load_jain_data(stage="parity")  # Fast, correct
+```
+
+## Future Work (Phase 4+)
+
+The base class (`AntibodyDataset`) provides infrastructure for:
+- `annotate_sequence()` - ANARCI annotation
+- `create_fragments()` - Fragment generation
+- `validate_sequences()` - Quality control
+
+These methods are available for:
+- Building NEW preprocessing pipelines
+- Prototyping dataset variations
+- Research experiments
+
+But they should **NOT** be used to replace the canonical preprocessing scripts.
+
+## Questions?
+
+- **"Should I use dataset classes or preprocessing scripts?"**
+  - **Preprocessing**: Use scripts in `preprocessing/`
+  - **Training**: Use dataset classes in `src/antibody_training_esm/datasets/`
+
+- **"Can I modify preprocessing logic in dataset classes?"**
+  - **No**. Preprocessing logic belongs in `preprocessing/` scripts.
+  - Dataset classes are for **loading**, not **creating** data.
+
+- **"How do I add a new dataset?"**
+  1. Create preprocessing script in `preprocessing/<dataset>/`
+  2. Create dataset loader in `src/antibody_training_esm/datasets/<dataset>.py`
+  3. Extend `AntibodyDataset` base class
+  4. Implement `load_data()` method
+
+---
+
+**Remember**: Preprocessing scripts are the **SSOT**. Dataset loaders are **abstractions** for training.
diff --git a/src/antibody_training_esm/datasets/__init__.py b/src/antibody_training_esm/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc4a52f0b1caeed84b2154279941be58593298f
--- /dev/null
+++ b/src/antibody_training_esm/datasets/__init__.py
@@ -0,0 +1,47 @@
+"""
+Dataset Loaders Module
+
+Professional dataset loaders following Open/Closed Principle.
+ +IMPORTANT: These classes LOAD preprocessed data, they do NOT run preprocessing pipelines. +For preprocessing, use the scripts in: preprocessing// + +This module provides: +- AntibodyDataset: Abstract base class for all dataset loaders +- JainDataset: Jain 2017 therapeutic antibody dataset loader +- HarveyDataset: Harvey nanobody polyreactivity dataset loader +- ShehataDataset: Shehata HIV antibody dataset loader +- BoughterDataset: Boughter mouse antibody dataset loader + +Each dataset class: +1. Implements dataset-specific loading logic for PREPROCESSED data +2. Provides common utilities (validation, statistics) +3. Can be extended without modifying existing code (OCP) + +Example usage: + >>> from antibody_training_esm.datasets import HarveyDataset + >>> dataset = HarveyDataset() + >>> df = dataset.load_data() + >>> print(f"Loaded {len(df)} sequences") +""" + +from .base import AntibodyDataset +from .boughter import BoughterDataset, load_boughter_data +from .harvey import HarveyDataset, load_harvey_data +from .jain import JainDataset, load_jain_data +from .shehata import ShehataDataset, load_shehata_data + +__all__ = [ + # Base class + "AntibodyDataset", + # Concrete dataset loader classes + "JainDataset", + "HarveyDataset", + "ShehataDataset", + "BoughterDataset", + # Convenience functions for loading + "load_jain_data", + "load_harvey_data", + "load_shehata_data", + "load_boughter_data", +] diff --git a/src/antibody_training_esm/datasets/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d5002578644f6dcb9c59672649dedcb2c29baf Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/base.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/base.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..eb26af86e54d49647a6e00335d045ed7e32e6d92 Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/base.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/boughter.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/boughter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0988faf9ca9e92645af7037058b8b2f8461d7a10 Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/boughter.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/default_paths.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/default_paths.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37a8ac2dedadb8580f8c0718c014c518b97943e7 Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/default_paths.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/ginkgo.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/ginkgo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e0ed2606c5762fa3ced1d59dd475f10d0dba4bf Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/ginkgo.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/harvey.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/harvey.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b4b6d9850ad53590a6e327a74d00403b0b6b832 Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/harvey.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/jain.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/jain.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..878ef63d2f8e918f8915d031f93f40e6974c63be Binary files /dev/null and 
b/src/antibody_training_esm/datasets/__pycache__/jain.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/__pycache__/shehata.cpython-312.pyc b/src/antibody_training_esm/datasets/__pycache__/shehata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be1caf2e06d2ec1c4204dcd4d953d2a77cafb66b Binary files /dev/null and b/src/antibody_training_esm/datasets/__pycache__/shehata.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/base.py b/src/antibody_training_esm/datasets/base.py new file mode 100644 index 0000000000000000000000000000000000000000..58431e669fb41a22d15a44c2f8e6ca9409a37f43 --- /dev/null +++ b/src/antibody_training_esm/datasets/base.py @@ -0,0 +1,307 @@ +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, cast + +import pandas as pd +import pandera.pandas as pa +from pandera.errors import SchemaError + +from antibody_training_esm.datasets.mixins.annotation_mixin import AnnotationMixin +from antibody_training_esm.datasets.mixins.fragment_mixin import FragmentMixin +from antibody_training_esm.schemas.dataset import get_sequence_dataset_schema + + +class AntibodyDataset(ABC, AnnotationMixin, FragmentMixin): + """ + Abstract base class for antibody dataset preprocessing. + + This class defines the common interface that all dataset preprocessors must implement + and provides shared utility methods for common operations like sequence validation, + ANARCI annotation, and fragment generation. 
+ + Design Principles: + - Single Responsibility: Each concrete class handles ONE dataset + - Open/Closed: New datasets extend this class without modifying it + - Dependency Inversion: High-level preprocessing depends on this abstraction + """ + + # Standard fragment types for full antibodies (VH + VL) + FULL_ANTIBODY_FRAGMENTS = [ + "VH_only", + "VL_only", + "VH+VL", + "H-CDR1", + "H-CDR2", + "H-CDR3", + "L-CDR1", + "L-CDR2", + "L-CDR3", + "H-CDRs", + "L-CDRs", + "All-CDRs", + "H-FWRs", + "L-FWRs", + "All-FWRs", + "Full", + ] + + # Standard fragment types for nanobodies (VHH only) + NANOBODY_FRAGMENTS = [ + "VHH_only", + "H-CDR1", + "H-CDR2", + "H-CDR3", + "H-CDRs", + "H-FWRs", + ] + + # Valid amino acid characters (20 standard + X for unknown/ambiguous) + # X is included for compatibility with ESM models which support ambiguous residues + VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWYX") + + def __init__( + self, + dataset_name: str, + output_dir: Path | None = None, + logger: logging.Logger | None = None, + ): + """ + Initialize dataset preprocessor. + + Args: + dataset_name: Name of the dataset (e.g., "jain", "harvey") + output_dir: Directory to write processed outputs + logger: Logger instance (creates default if None) + """ + self.dataset_name = dataset_name + self.output_dir = ( + Path(output_dir) if output_dir else Path(f"experiments/runs/{dataset_name}") + ) + self.logger = logger or self._create_default_logger() + + # Create output directory if it doesn't exist + self.output_dir.mkdir(parents=True, exist_ok=True) + + @classmethod + def get_schema(cls) -> pa.DataFrameSchema: + """ + Get the Pandera schema for this dataset. + Subclasses should override this method. + """ + return get_sequence_dataset_schema() + + @classmethod + def validate_dataframe(cls, df: pd.DataFrame) -> pd.DataFrame: + """ + Validate DataFrame against Pandera schema. 
+ + Args: + df: Raw DataFrame from CSV + + Returns: + Validated DataFrame (possibly coerced types) + + Raises: + ValueError: If validation fails (wraps SchemaError) + """ + try: + import pandera.backends.pandas # noqa: F401 + + # Use lazy=False to fail fast (default behavior) + # Note: SequenceDatasetSchema uses lazy=False in its definition implicitly + validated_df = cast(pd.DataFrame, cls.get_schema().validate(df, lazy=False)) + return validated_df + except SchemaError as e: + # Enhance error message with dataset context + raise ValueError( + f"Schema validation failed for {cls.__name__}:\n{e}" + ) from e + + def _create_default_logger(self) -> logging.Logger: + """Create a default logger if none provided""" + logger = logging.getLogger( + f"antibody_training_esm.datasets.{self.dataset_name}" + ) + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + return logger + + # ========== ABSTRACT METHODS (MUST BE IMPLEMENTED) ========== + + @abstractmethod + def load_data(self, **kwargs: Any) -> pd.DataFrame: + """ + Load raw dataset from source files. + + This method must be implemented by each dataset since data loading + is dataset-specific (Excel, CSV, FASTA, etc.). + + Returns: + DataFrame with columns: id, VH_sequence, VL_sequence (optional), label + """ + pass + + @abstractmethod + def get_fragment_types(self) -> list[str]: + """ + Return the list of fragment types for this dataset. + + Most datasets use FULL_ANTIBODY_FRAGMENTS (16 types). + Nanobody datasets (like Harvey) use NANOBODY_FRAGMENTS (6 types). + + Returns: + List of fragment type names + """ + pass + + # ========== COMMON UTILITY METHODS ========== + + def sanitize_sequence(self, sequence: str) -> str: + """ + Clean and validate a protein sequence. 
+ + Operations: + - Remove gap characters (-) + - Remove whitespace + - Convert to uppercase + - Validate amino acids + + Args: + sequence: Raw protein sequence + + Returns: + Cleaned sequence + + Raises: + ValueError: If sequence contains invalid characters + """ + if not sequence or not isinstance(sequence, str): + raise ValueError("Sequence must be a non-empty string") + + # Remove gaps and whitespace + sequence = sequence.replace("-", "").replace(" ", "").upper() + + # Validate amino acids + invalid_chars = set(sequence) - self.VALID_AMINO_ACIDS + if invalid_chars: + raise ValueError( + f"Sequence contains invalid amino acids: {invalid_chars}. " + f"Valid amino acids: {self.VALID_AMINO_ACIDS}" + ) + + return sequence + + def validate_sequences(self, df: pd.DataFrame) -> dict[str, Any]: + """ + Validate all sequences in a DataFrame. + + Checks: + - Valid amino acids + - Sequence lengths + - Missing sequences + + Args: + df: DataFrame with VH_sequence and optionally VL_sequence columns + + Returns: + Dictionary with validation statistics + """ + stats = { + "total_sequences": len(df), + "valid_sequences": 0, + "invalid_sequences": 0, + "missing_vh": 0, + "missing_vl": 0, + "length_stats": {}, + } + + # Check VH sequences + if "VH_sequence" in df.columns: + missing_vh = int(df["VH_sequence"].isna().sum()) + stats["missing_vh"] = missing_vh + valid_vh = df["VH_sequence"].notna() + + if valid_vh.any(): + vh_lengths = df.loc[valid_vh, "VH_sequence"].str.len() + # Cast to help mypy understand the type + length_stats = cast( + dict[str, dict[str, int | float]], stats["length_stats"] + ) + length_stats["VH"] = { + "min": int(vh_lengths.min()), + "max": int(vh_lengths.max()), + "mean": float(vh_lengths.mean()), + } + + # Check VL sequences (if present) + if "VL_sequence" in df.columns: + missing_vl = int(df["VL_sequence"].isna().sum()) + stats["missing_vl"] = missing_vl + valid_vl = df["VL_sequence"].notna() + + if valid_vl.any(): + vl_lengths = df.loc[valid_vl, 
"VL_sequence"].str.len() + # Cast to help mypy understand the type + length_stats = cast( + dict[str, dict[str, int | float]], stats["length_stats"] + ) + length_stats["VL"] = { + "min": int(vl_lengths.min()), + "max": int(vl_lengths.max()), + "mean": float(vl_lengths.mean()), + } + + # Use explicit variable for type safety + missing_vh_count = cast(int, stats["missing_vh"]) + stats["valid_sequences"] = len(df) - missing_vh_count + stats["invalid_sequences"] = missing_vh_count + + return stats + + def print_statistics(self, df: pd.DataFrame, stage: str = "Final") -> None: + """ + Print dataset statistics to logger. + + Args: + df: DataFrame with processed data + stage: Stage name for logging (e.g., "Raw", "Filtered", "Final") + """ + self.logger.info(f"\n{'=' * 60}") + self.logger.info(f"{stage} Dataset Statistics - {self.dataset_name}") + self.logger.info(f"{'=' * 60}") + + # Basic counts + self.logger.info(f"Total sequences: {len(df)}") + + # Label distribution + if "label" in df.columns: + label_counts = df["label"].value_counts() + self.logger.info("\nLabel distribution:") + for label, count in label_counts.items(): + percentage = (count / len(df)) * 100 + label_name = "Non-specific" if label == 1 else "Specific" + self.logger.info( + f" {label_name} (label={label}): {count} ({percentage:.1f}%)" + ) + + # Sequence validation stats + val_stats = self.validate_sequences(df) + self.logger.info("\nSequence validation:") + self.logger.info(f" Valid sequences: {val_stats['valid_sequences']}") + self.logger.info(f" Invalid sequences: {val_stats['invalid_sequences']}") + + if val_stats["length_stats"]: + self.logger.info("\nSequence length statistics:") + for chain, stats in val_stats["length_stats"].items(): + self.logger.info( + f" {chain}: min={stats['min']}, max={stats['max']}, mean={stats['mean']:.1f}" + ) + + self.logger.info(f"{'=' * 60}\n") diff --git a/src/antibody_training_esm/datasets/boughter.py b/src/antibody_training_esm/datasets/boughter.py new file mode 
"""
Boughter Dataset Loader

Loads preprocessed Boughter mouse antibody dataset.

IMPORTANT: This module is for LOADING preprocessed data, not for running
the preprocessing pipeline. The preprocessing scripts that CREATE the data
are in: preprocessing/boughter/stage2_stage3_annotation_qc.py

Dataset characteristics:
- Full antibodies (VH + VL)
- Mouse antibodies from 6 subsets (flu, hiv, gut, mouse IgA)
- DNA sequences requiring translation to protein
- Novo flagging strategy (0/1-3/4+ flags)
- 3-stage quality control pipeline
- 16 fragment types (full antibody)

Processing Pipeline:
    Stage 1: DNA translation (FASTA -> protein sequences)
    Stage 2: ANARCI annotation (riot_na)
    Stage 3: Post-annotation QC (filter X in CDRs, empty CDRs)

Source:
- data/train/boughter/raw/ (multiple subsets)
- Sequences in DNA format requiring translation

Reference:
- Boughter et al., "Biochemical patterns of antibody polyreactivity revealed through a bioinformatics-based analysis of CDR loops"
"""

import logging
from pathlib import Path
from typing import Any, NoReturn

import pandas as pd
import pandera.pandas as pa

from antibody_training_esm.schemas.dataset import get_boughter_schema
from antibody_training_esm.settings import settings

from .base import AntibodyDataset

BOUGHTER_ANNOTATED_DIR = settings.BOUGHTER_ANNOTATED_DIR
BOUGHTER_PROCESSED_CSV = settings.BOUGHTER_PROCESSED_CSV


class BoughterDataset(AntibodyDataset):
    """
    Loader for Boughter mouse antibody dataset.

    This class provides an interface to LOAD preprocessed Boughter dataset files.
    It does NOT run the preprocessing pipeline - use
    preprocessing/boughter/stage2_stage3_annotation_qc.py for that.

    The Boughter dataset originally requires DNA translation before standard
    preprocessing. Sequences are provided as DNA in FASTA format and must be
    translated to protein sequences using a hybrid translation strategy
    (done by the preprocessing scripts).
    """

    # Novo flagging strategy
    FLAG_SPECIFIC = 0  # 0 flags = specific (include in training)
    FLAG_MILD = [1, 2, 3]  # 1-3 flags = mild (EXCLUDE from training)
    # 4+ flags = non-specific. NOTE: labels are derived with `flags >= 4`
    # below, so this list is informational/documentation only.
    FLAG_NONSPECIFIC = [4, 5, 6, 7]

    # Dataset subsets
    SUBSETS = ["flu", "hiv_nat", "hiv_cntrl", "hiv_plos", "gut_hiv", "mouse_iga"]

    @classmethod
    def get_schema(cls) -> pa.DataFrameSchema:
        """Return the Boughter-specific Pandera schema."""
        return get_boughter_schema()

    def __init__(
        self, output_dir: Path | None = None, logger: logging.Logger | None = None
    ):
        """
        Initialize Boughter dataset loader.

        Args:
            output_dir: Directory containing preprocessed fragment files
            logger: Logger instance
        """
        super().__init__(
            dataset_name="boughter",
            output_dir=output_dir or BOUGHTER_ANNOTATED_DIR,
            logger=logger,
        )

    def get_fragment_types(self) -> list[str]:
        """
        Return full antibody fragment types.

        Boughter contains VH + VL sequences, so we generate all 16 fragment types.

        Returns:
            List of 16 full antibody fragment types
        """
        return self.FULL_ANTIBODY_FRAGMENTS

    def load_data(
        self,
        processed_csv: str | Path | None = None,
        subset: str | None = None,
        include_mild: bool = False,
        **_: Any,
    ) -> pd.DataFrame:
        """
        Load Boughter dataset from processed CSV.

        Note: This assumes DNA translation has already been performed.
        For DNA translation from FASTA files, use the preprocessing scripts
        in preprocessing/boughter/

        Args:
            processed_csv: Path to processed CSV with protein sequences
            subset: Specific subset to load (flu, hiv_nat, etc.) or None for all
            include_mild: If True, include mild (1-3 flags). Default False.

        Returns:
            DataFrame with columns: id, VH_sequence, VL_sequence, label
            (plus flags/include_in_training when the source file carries a
            flags column; pre-filtered training files do not)

        Raises:
            FileNotFoundError: If processed CSV not found
            ValueError: If an unknown subset is requested or schema
                validation fails
        """
        # Default path
        if processed_csv is None:
            processed_csv = BOUGHTER_PROCESSED_CSV

        csv_file = Path(processed_csv)
        if not csv_file.exists():
            raise FileNotFoundError(
                f"Boughter processed CSV not found: {csv_file}\n"
                f"Please run DNA translation preprocessing first:\n"
                f"  python preprocessing/boughter/stage1_dna_translation.py"
            )

        # Load data
        self.logger.info(f"Reading Boughter dataset from {csv_file}...")
        df = pd.read_csv(csv_file)
        self.logger.info(f"  Loaded {len(df)} sequences")

        # Filter by subset if specified
        if subset is not None:
            if subset not in self.SUBSETS:
                raise ValueError(f"Unknown subset: {subset}. Valid: {self.SUBSETS}")
            df = df[df["subset"] == subset].copy()
            self.logger.info(f"  Filtered to subset '{subset}': {len(df)} sequences")

        # Apply Novo flagging strategy (only if flags column exists)
        # Pre-filtered training files (e.g. *_training.csv) don't have a flags column
        if not include_mild:
            # Flag column may be named 'num_flags' or 'flags'
            has_flags = "num_flags" in df.columns or "flags" in df.columns

            if has_flags:
                # Exclude mild (1-3 flags) per Novo Nordisk methodology
                flag_col = "num_flags" if "num_flags" in df.columns else "flags"
                df["include_in_training"] = ~df[flag_col].isin(self.FLAG_MILD)
                df_training = df[df["include_in_training"]].copy()

                excluded = len(df) - len(df_training)
                self.logger.info("\nNovo flagging strategy:")
                self.logger.info(
                    f"  Excluded {excluded} sequences with mild flags (1-3)"
                )
                self.logger.info(f"  Training set: {len(df_training)} sequences")

                df = df_training
            else:
                # File is pre-filtered (training subset) - no flags column
                self.logger.info(
                    "  No flags column found - assuming pre-filtered training data"
                )

        # Standardize column names.
        # BUGFIX: the rename was previously gated on "heavy_seq" only, so a
        # file carrying just "light_seq" was never standardized; rename
        # whichever legacy columns are actually present.
        column_mapping = {
            "heavy_seq": "VH_sequence",
            "light_seq": "VL_sequence",
        }
        present_legacy = {
            old: new for old, new in column_mapping.items() if old in df.columns
        }
        if present_legacy:
            df = df.rename(columns=present_legacy)

        # Create binary labels from flags:
        #   0 flags  -> specific (label=0)
        #   4+ flags -> non-specific (label=1)
        # If neither flag column exists, the file is expected to already
        # carry a label column (enforced by schema validation below).
        flag_col = "num_flags" if "num_flags" in df.columns else "flags"
        if flag_col in df.columns:
            df["label"] = (df[flag_col] >= 4).astype(int)

        # Create 'sequence' column for schema validation (use VH)
        if "sequence" not in df.columns and "VH_sequence" in df.columns:
            df["sequence"] = df["VH_sequence"]

        # Validate with Pandera
        df = self.validate_dataframe(df)

        self.logger.info("\nLabel distribution:")
        label_counts = df["label"].value_counts().sort_index()
        for label, count in label_counts.items():
            label_name = "Specific" if label == 0 else "Non-specific"
            percentage = (count / len(df)) * 100
            self.logger.info(
                f"  {label_name} (label={label}): {count} ({percentage:.1f}%)"
            )

        return df

    def translate_dna_to_protein(self, dna_sequence: str) -> NoReturn:  # noqa: ARG002
        """
        This method is NOT IMPLEMENTED and will always raise an error.

        DNA translation logic belongs in the preprocessing scripts, not in
        dataset loader classes. Loaders are for LOADING preprocessed data,
        not for creating it.

        For DNA translation, use:
            preprocessing/boughter/stage1_dna_translation.py

        Args:
            dna_sequence: DNA sequence string (unused - always raises)

        Raises:
            NotImplementedError: Always - this method intentionally does nothing
        """
        raise NotImplementedError(
            "DNA translation is not implemented in dataset loader classes.\n"
            "Dataset loaders are for LOADING preprocessed data, not creating it.\n"
            "\n"
            "For DNA translation, use the preprocessing script:\n"
            "  python preprocessing/boughter/stage1_dna_translation.py\n"
            "\n"
            "This script implements the full hybrid translation strategy:\n"
            "  1. Direct V-domain translation (pre-trimmed sequences)\n"
            "  2. ATG-based translation (full-length with signal peptide)\n"
            "  3. V-domain motif detection (EVQL, QVQL, etc.)"
        )

    def filter_quality_issues(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Stage 3 QC: Filter sequences with quality issues.

        Removes:
        - Sequences with X in CDRs (ambiguous amino acids)
        - Sequences with empty or missing (NaN) CDRs

        Args:
            df: Annotated DataFrame

        Returns:
            Filtered DataFrame
        """
        initial_count = len(df)

        # Heavy/light CDR columns produced by annotation
        cdr_cols = [
            col for col in df.columns if "CDR" in col and ("VH_" in col or "VL_" in col)
        ]

        # Single pass per column: drop rows whose CDR contains an ambiguous
        # residue (X) or is empty. NaN CDRs are dropped too: str.len() yields
        # NaN, and NaN > 0 evaluates False. (The previous implementation made
        # two passes per column and re-checked `col in df.columns` even though
        # cdr_cols was derived from df.columns - same row set, less work.)
        for col in cdr_cols:
            has_x = df[col].str.contains("X", na=False)
            non_empty = df[col].str.len() > 0
            df = df[~has_x & non_empty].copy()

        filtered_count = initial_count - len(df)

        if filtered_count > 0:
            self.logger.info(f"\nStage 3 QC filtered {filtered_count} sequences:")
            self.logger.info(f"  Remaining: {len(df)} sequences")

        return df


# ========== CONVENIENCE FUNCTIONS FOR LOADING DATA ==========


def load_boughter_data(
    processed_csv: str | None = None,
    subset: str | None = None,
    include_mild: bool = False,
) -> pd.DataFrame:
    """
    Convenience function to load preprocessed Boughter dataset.

    IMPORTANT: This loads PREPROCESSED data. To preprocess raw data, use:
        preprocessing/boughter/stage2_stage3_annotation_qc.py

    Args:
        processed_csv: Path to processed CSV with protein sequences
        subset: Specific subset to load or None for all
        include_mild: If True, include mild (1-3 flags)

    Returns:
        DataFrame with preprocessed data

    Example:
        >>> from antibody_training_esm.datasets.boughter import load_boughter_data
        >>> df = load_boughter_data(include_mild=False)  # Novo flagging
        >>> print(f"Loaded {len(df)} sequences")
    """
    dataset = BoughterDataset()
    return dataset.load_data(
        processed_csv=processed_csv, subset=subset, include_mild=include_mild
    )
"""
Harvey Dataset Loader

Loads preprocessed Harvey nanobody polyreactivity dataset.

IMPORTANT: This module is for LOADING preprocessed data, not for running
the preprocessing pipeline. The preprocessing scripts that CREATE the data
are in: preprocessing/harvey/step2_extract_fragments.py

Dataset characteristics:
- Nanobodies (VHH only, no light chain)
- High-throughput screen data (141,474 sequences)
- Binary classification: high/low polyreactivity
- IMGT-numbered positions in raw data
- 6 fragment types (VHH-specific)

Source:
- data/test/harvey/raw/high_polyreactivity_high_throughput.csv
- data/test/harvey/raw/low_polyreactivity_high_throughput.csv

Reference:
- Harvey et al., Engineering highly expressed antibodies for nanobody discovery platforms
"""

import logging
from pathlib import Path
from typing import Any

import pandas as pd
import pandera.pandas as pa

from antibody_training_esm.schemas.dataset import get_harvey_schema
from antibody_training_esm.settings import settings

from .base import AntibodyDataset

HARVEY_HIGH_POLY_CSV = settings.HARVEY_HIGH_POLY_CSV
HARVEY_LOW_POLY_CSV = settings.HARVEY_LOW_POLY_CSV
HARVEY_OUTPUT_DIR = settings.HARVEY_OUTPUT_DIR


class HarveyDataset(AntibodyDataset):
    """
    Loader for Harvey nanobody dataset.

    This class provides an interface to LOAD preprocessed Harvey dataset files.
    It does NOT run the preprocessing pipeline - use
    preprocessing/harvey/step2_extract_fragments.py for that.

    The Harvey dataset contains VHH sequences (heavy chain only, no light
    chain) from a high-throughput polyreactivity screen. Sequences are
    provided with IMGT numbering and pre-extracted CDR regions.
    """

    def __init__(
        self, output_dir: Path | None = None, logger: logging.Logger | None = None
    ):
        """
        Initialize Harvey dataset loader.

        Args:
            output_dir: Directory containing preprocessed fragment files
            logger: Logger instance
        """
        super().__init__(
            dataset_name="harvey",
            output_dir=output_dir or HARVEY_OUTPUT_DIR,
            logger=logger,
        )

    @classmethod
    def get_schema(cls) -> pa.DataFrameSchema:
        """Return the Harvey-specific Pandera schema."""
        return get_harvey_schema()

    def get_fragment_types(self) -> list[str]:
        """
        Return nanobody-specific fragment types.

        Harvey contains VHH sequences only (no light chain), so we generate
        6 fragment types instead of the full 16.

        Returns:
            List of 6 nanobody fragment types
        """
        return self.NANOBODY_FRAGMENTS

    def extract_sequence_from_imgt(self, row: pd.Series, imgt_cols: list[str]) -> str:
        """
        Extract full sequence from IMGT-numbered position columns.

        The Harvey raw data contains columns "1" through "128" representing
        IMGT numbering positions. This method concatenates non-gap, non-NaN
        positions to reconstruct the full sequence.

        Args:
            row: DataFrame row with IMGT position columns
            imgt_cols: List of column names ['1', '2', ..., '128']

        Returns:
            Full sequence string with gaps removed
        """
        positions = []
        for col in imgt_cols:
            # `col in row` tests the row's index, guarding absent columns
            if col in row and pd.notna(row[col]) and row[col] != "-":
                positions.append(row[col])
        return "".join(positions)

    def load_data(
        self,
        high_csv_path: str | Path | None = None,
        low_csv_path: str | Path | None = None,
        **_: Any,
    ) -> pd.DataFrame:
        """
        Load Harvey dataset from high/low polyreactivity CSV files.

        Args:
            high_csv_path: Path to high_polyreactivity_high_throughput.csv
            low_csv_path: Path to low_polyreactivity_high_throughput.csv

        Returns:
            DataFrame with columns: id, VH_sequence, label (plus a duplicate
            'sequence' column required by the schema)

        Raises:
            FileNotFoundError: If input CSV files not found
            ValueError: If an input CSV is empty or schema validation fails
        """
        # Default paths
        if high_csv_path is None:
            high_csv_path = HARVEY_HIGH_POLY_CSV
        if low_csv_path is None:
            low_csv_path = HARVEY_LOW_POLY_CSV

        # Validate paths
        high_csv = Path(high_csv_path)
        low_csv = Path(low_csv_path)

        if not high_csv.exists():
            raise FileNotFoundError(
                f"High polyreactivity CSV not found: {high_csv}\n"
                f"Please ensure raw files are in data/test/harvey/raw/"
            )

        if not low_csv.exists():
            raise FileNotFoundError(
                f"Low polyreactivity CSV not found: {low_csv}\n"
                f"Please ensure raw files are in data/test/harvey/raw/"
            )

        # Load datasets
        self.logger.info(f"Reading high polyreactivity data from {high_csv}...")
        df_high = pd.read_csv(high_csv)
        if len(df_high) == 0:
            raise ValueError(
                f"Loaded dataset is empty: {high_csv}\n"
                "The CSV file may be corrupted or truncated. "
                "Please check the file or re-run preprocessing."
            )
        self.logger.info(f"  Loaded {len(df_high)} high polyreactivity sequences")

        self.logger.info(f"Reading low polyreactivity data from {low_csv}...")
        df_low = pd.read_csv(low_csv)
        if len(df_low) == 0:
            raise ValueError(
                f"Loaded dataset is empty: {low_csv}\n"
                "The CSV file may be corrupted or truncated. "
                "Please check the file or re-run preprocessing."
            )
        self.logger.info(f"  Loaded {len(df_low)} low polyreactivity sequences")

        # IMGT position columns (1-128). PERF: restrict to columns that are
        # actually present in each frame ONCE, instead of re-testing
        # `col in row` for up to 128 absent columns on every one of ~141k
        # rows. The reconstructed sequences are identical either way.
        imgt_cols = [str(i) for i in range(1, 129)]
        high_cols = [c for c in imgt_cols if c in df_high.columns]
        low_cols = [c for c in imgt_cols if c in df_low.columns]

        # Extract full sequences from IMGT positions
        self.logger.info("Extracting sequences from IMGT positions...")
        df_high["VH_sequence"] = df_high.apply(
            lambda row: self.extract_sequence_from_imgt(row, high_cols), axis=1
        )
        df_low["VH_sequence"] = df_low.apply(
            lambda row: self.extract_sequence_from_imgt(row, low_cols), axis=1
        )

        # Add binary labels
        df_high["label"] = 1  # high polyreactivity = non-specific
        df_low["label"] = 0  # low polyreactivity = specific

        # Combine datasets
        self.logger.info("Combining high and low polyreactivity datasets...")
        df_combined = pd.concat([df_high, df_low], ignore_index=True)

        # Create sequence IDs
        df_combined["id"] = [f"harvey_{i:06d}" for i in range(len(df_combined))]

        # Select standardized columns
        df_output = df_combined[["id", "VH_sequence", "label"]].copy()

        # Filter out empty sequences
        empty_mask = df_output["VH_sequence"].str.len() == 0
        if empty_mask.any():
            n_empty = empty_mask.sum()
            self.logger.warning(f"Removing {n_empty} sequences with zero length")
            df_output = df_output[~empty_mask].reset_index(drop=True)

        # Create 'sequence' column for schema validation (use VH)
        if "sequence" not in df_output.columns and "VH_sequence" in df_output.columns:
            df_output["sequence"] = df_output["VH_sequence"]

        # Validate with Pandera
        df_output = self.validate_dataframe(df_output)

        self.logger.info(f"Combined dataset: {len(df_output)} sequences")
        self.logger.info(
            f"  High polyreactivity (label=1): {(df_output['label'] == 1).sum()}"
        )
        self.logger.info(
            f"  Low polyreactivity (label=0): {(df_output['label'] == 0).sum()}"
        )

        # Sequence length stats. Guarded: min()/max() on an empty frame would
        # log NaN if every sequence had been filtered out above.
        if not df_output.empty:
            seq_lengths = df_output["VH_sequence"].str.len()
            self.logger.info(
                f"Sequence length range: {seq_lengths.min()}-{seq_lengths.max()} aa "
                f"(mean: {seq_lengths.mean():.1f})"
            )

        return df_output


# ========== CONVENIENCE FUNCTIONS FOR LOADING DATA ==========


def load_harvey_data(
    high_csv: str | None = None,
    low_csv: str | None = None,
) -> pd.DataFrame:
    """
    Convenience function to load preprocessed Harvey dataset.

    IMPORTANT: This loads PREPROCESSED data. To preprocess raw data, use:
        preprocessing/harvey/step2_extract_fragments.py

    Args:
        high_csv: Path to high polyreactivity CSV
        low_csv: Path to low polyreactivity CSV

    Returns:
        DataFrame with preprocessed data

    Example:
        >>> from antibody_training_esm.datasets.harvey import load_harvey_data
        >>> df = load_harvey_data()
        >>> print(f"Loaded {len(df)} sequences")
    """
    dataset = HarveyDataset()
    return dataset.load_data(high_csv_path=high_csv, low_csv_path=low_csv)
The preprocessing scripts that CREATE the data +are in: preprocessing/jain/step2_preprocess_p5e_s2.py + +Dataset characteristics: +- Full antibodies (VH + VL) +- 137 FDA-approved/clinical-stage therapeutics +- Multi-stage filtering with biophysical parameters +- Novo Nordisk parity requirements (86 antibodies, [[40, 19], [10, 17]]) +- 16 fragment types (full antibody) + +Processing Pipeline: + 137 antibodies (FULL) + ↓ Remove ELISA 1-3 (mild aggregators) + 116 antibodies (SSOT) + ↓ Reclassify 5 spec→nonspec (PSR>0.4, Tm, clinical) + 89 spec / 27 nonspec + ↓ Remove 30 by PSR/AC-SINS ranking + 86 antibodies (59 spec / 27 nonspec) - NOVO PARITY + +Source: +- data/test/jain/processed/jain_with_private_elisa_FULL.csv +- data/test/jain/processed/jain_sd03.csv (biophysical data) + +Reference: +- Jain et al. (2017), "Biophysical properties of the clinical-stage antibody landscape" +""" + +import logging +from pathlib import Path +from typing import Any, cast + +import pandas as pd +import pandera.pandas as pa +from pandera.errors import SchemaError + +from antibody_training_esm.schemas.dataset import ( + get_jain_preprocessing_schema, + get_jain_schema, +) +from antibody_training_esm.settings import settings + +from .base import AntibodyDataset + +JAIN_FULL_CSV = settings.JAIN_FULL_CSV +JAIN_OUTPUT_DIR = settings.JAIN_OUTPUT_DIR +JAIN_SD03_CSV = settings.JAIN_SD03_CSV + + +# Novo Nordisk Parity Constants (from Sakhnini et al. 2025) +# Paper benchmark: 86 antibodies with [[40, 19], [10, 17]] confusion matrix +NOVO_PARITY_SPECIFIC_COUNT = 59 # Specific antibodies in parity set +NOVO_PARITY_NONSPECIFIC_COUNT = 27 # Non-specific antibodies in parity set +NOVO_PARITY_TOTAL = 86 # Total parity set size (59 + 27) +NOVO_PARITY_EXPECTED_CORRECT = 57 # Expected correct predictions (40 + 17) +NOVO_PARITY_ACCURACY = 66.28 # Expected accuracy (57/86 = 0.6628) + + +class JainDataset(AntibodyDataset): + """ + Loader for Jain therapeutic antibody dataset. 
+ + This class provides an interface to LOAD preprocessed Jain dataset files. + It does NOT run the preprocessing pipeline - use preprocessing/jain/step2_preprocess_p5e_s2.py for that. + + The Jain dataset contains FDA-approved and clinical-stage therapeutic antibodies + with complex multi-stage filtering to achieve Novo Nordisk parity. + """ + + # P5e-S2 Method Constants (Novo Nordisk parity) + PSR_THRESHOLD = 0.4 + + # Reclassification tiers + TIER_A_PSR = ["bimagrumab", "bavituximab", "ganitumab"] # PSR >0.4 + TIER_B_EXTREME_TM = "eldelumab" # Extreme Tm outlier (59.50°C) + TIER_C_CLINICAL = "infliximab" # 61% ADA rate + chimeric + + @classmethod + def get_schema(cls) -> pa.DataFrameSchema: + return get_jain_schema() + + def __init__( + self, output_dir: Path | None = None, logger: logging.Logger | None = None + ): + """ + Initialize Jain dataset loader. + + Args: + output_dir: Directory containing preprocessed fragment files + logger: Logger instance + """ + super().__init__( + dataset_name="jain", + output_dir=output_dir or JAIN_OUTPUT_DIR, + logger=logger, + ) + + def get_fragment_types(self) -> list[str]: + """ + Return full antibody fragment types. + + Jain contains VH + VL sequences, so we generate all 16 fragment types. + + Returns: + List of 16 full antibody fragment types + """ + return self.FULL_ANTIBODY_FRAGMENTS + + def load_data( + self, + full_csv_path: str | Path | None = None, + sd03_csv_path: str | Path | None = None, + stage: str = "full", + **_: Any, + ) -> pd.DataFrame: + """ + Load Jain dataset from CSV files. 
+ + Args: + full_csv_path: Path to jain_with_private_elisa_FULL.csv (137 antibodies) + sd03_csv_path: Path to jain_sd03.csv (biophysical data) + stage: Which processing stage to load: + "full" - 137 antibodies (raw) + "ssot" - 116 antibodies (ELISA-filtered) + "parity" - 86 antibodies (Novo parity) + + Returns: + DataFrame with columns: id, VH_sequence, VL_sequence, label, elisa_flags, psr, ac_sins, hic, fab_tm + + Raises: + FileNotFoundError: If input CSV files not found + ValueError: If stage is invalid + """ + # Validate stage + valid_stages = {"full", "ssot", "parity"} + if stage not in valid_stages: + raise ValueError(f"Invalid stage '{stage}'. Must be one of: {valid_stages}") + + # Default paths + if full_csv_path is None: + full_csv_path = JAIN_FULL_CSV + if sd03_csv_path is None: + sd03_csv_path = JAIN_SD03_CSV + + full_csv = Path(full_csv_path) + sd03_csv = Path(sd03_csv_path) + + if not full_csv.exists(): + raise FileNotFoundError( + f"Jain FULL CSV not found: {full_csv}\n" + f"Please ensure source data is in data/test/jain/processed/" + ) + + # Load main dataset + self.logger.info(f"Reading Jain FULL dataset from {full_csv}...") + df = pd.read_csv(full_csv) + + # Validate dataset is not empty + if len(df) == 0: + raise ValueError( + f"Loaded dataset is empty: {full_csv}\n" + "The CSV file may be corrupted or truncated. " + "Please check the file or re-run preprocessing." 
+ ) + + self.logger.info(f" Loaded {len(df)} antibodies") + self.logger.info(f" Specific: {(df['label'] == 0).sum()}") + self.logger.info(f" Non-specific: {(df['label'] == 1).sum()}") + + # Standardize column names + column_mapping = { + "heavy_seq": "VH_sequence", + "light_seq": "VL_sequence", + "vh_sequence": "VH_sequence", # Support VH-only files + "vl_sequence": "VL_sequence", # Support VL-only files + } + df = df.rename(columns=column_mapping) + + # Load biophysical data if available + if sd03_csv.exists(): + self.logger.info(f"Loading biophysical data from {sd03_csv}...") + sd03 = pd.read_csv(sd03_csv) + + # Merge biophysical columns + df = df.merge( + sd03[ + [ + "Name", + "Poly-Specificity Reagent (PSR) SMP Score (0-1)", + "Affinity-Capture Self-Interaction Nanoparticle Spectroscopy (AC-SINS) āˆ†Ī»max (nm) Average", + "HIC Retention Time (Min)a", + "Fab Tm by DSF (°C)", + ] + ], + left_on="id", + right_on="Name", + how="left", + ) + + # Rename for easier handling + df = df.rename( + columns={ + "Poly-Specificity Reagent (PSR) SMP Score (0-1)": "psr", + "Affinity-Capture Self-Interaction Nanoparticle Spectroscopy (AC-SINS) āˆ†Ī»max (nm) Average": "ac_sins", + "HIC Retention Time (Min)a": "hic", + "Fab Tm by DSF (°C)": "fab_tm", + } + ) + df = df.drop(columns=["Name"]) + + self.logger.info(" Biophysical data merged") + self.logger.info(f" Missing PSR: {df['psr'].isna().sum()}") + self.logger.info(f" Missing AC-SINS: {df['ac_sins'].isna().sum()}") + + # Apply stage-specific filtering + if stage == "ssot": + df = self.filter_elisa_1to3(df) + elif stage == "parity": + df = self.filter_elisa_1to3(df) + df = self.reclassify_5_antibodies(df) + df = self.remove_30_by_psr_acsins(df) + + # Create 'sequence' column for schema validation (use VH) + if "sequence" not in df.columns and "VH_sequence" in df.columns: + df["sequence"] = df["VH_sequence"] + + # Validate with Pandera + if stage == "full": + # Use preprocessing schema (allows NaN labels) for full stage + try: + 
validated_df = get_jain_preprocessing_schema().validate(df, lazy=False) + df = cast(pd.DataFrame, validated_df) + except SchemaError as e: + raise ValueError( + f"Schema validation failed for JainDataset (stage='full'):\n{e}" + ) from e + else: + # Use strict schema (no NaN labels) for filtered stages + df = self.validate_dataframe(df) + + return df + + def filter_elisa_1to3(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Remove ELISA 1-3 (mild aggregators) → 116 antibodies (SSOT). + + ELISA flags 1-3 indicate mild to moderate aggregation in ELISA assays. + These are filtered out as they don't represent strong enough + polyreactivity signal for training. + + Args: + df: Full dataset (137 antibodies) + + Returns: + Filtered dataset (116 antibodies) + """ + initial_count = len(df) + df_filtered = df[~df["elisa_flags"].isin([1, 2, 3])].copy() + removed_count = initial_count - len(df_filtered) + + self.logger.info("\nFiltering ELISA 1-3 (mild aggregators):") + self.logger.info(f" Removed: {removed_count} antibodies") + self.logger.info(f" Remaining: {len(df_filtered)} antibodies") + + return df_filtered + + def reclassify_5_antibodies(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Reclassify 5 specific → non-specific. 
+ + Tier A (PSR-based, 3 antibodies): + - bimagrumab (PSR=0.697) + - bavituximab (PSR=0.557) + - ganitumab (PSR=0.553) + All have ELISA=0 but PSR >0.4, indicating polyreactivity + + Tier B (Multi-metric, 1 antibody): + - eldelumab (Tm=59.50°C, extreme thermal instability outlier) + + Tier C (Clinical, 1 antibody): + - infliximab (61% ADA rate in NEJM study + chimeric + aggregation) + + Args: + df: 116-antibody dataset + + Returns: + Dataset with 5 antibodies reclassified (89 spec, 27 nonspec) + """ + df = df.copy() + df["label_original"] = df["label"] + df["reclassified"] = False + df["reclassification_reason"] = "" + + # Tier A: PSR >0.4 + for ab_id in self.TIER_A_PSR: + idx = df[df["id"] == ab_id].index + if len(idx) > 0: + df.loc[idx, "label"] = 1 + df.loc[idx, "reclassified"] = True + df.loc[idx, "reclassification_reason"] = "Tier A: PSR >0.4" + + # Tier B: Extreme Tm + idx = df[df["id"] == self.TIER_B_EXTREME_TM].index + if len(idx) > 0: + df.loc[idx, "label"] = 1 + df.loc[idx, "reclassified"] = True + df.loc[idx, "reclassification_reason"] = "Tier B: Extreme Tm" + + # Tier C: Clinical evidence + idx = df[df["id"] == self.TIER_C_CLINICAL].index + if len(idx) > 0: + df.loc[idx, "label"] = 1 + df.loc[idx, "reclassified"] = True + df.loc[idx, "reclassification_reason"] = "Tier C: Clinical (61% ADA)" + + spec_count = (df["label"] == 0).sum() + nonspec_count = (df["label"] == 1).sum() + + self.logger.info("\nReclassified 5 antibodies:") + self.logger.info(f" Specific: {spec_count} (expected 89)") + self.logger.info(f" Non-specific: {nonspec_count} (expected 27)") + + return df + + def remove_30_by_psr_acsins(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Remove 30 specific antibodies by PSR primary, AC-SINS tiebreaker. + + Removal strategy: + 1. Sort specific antibodies by PSR descending (primary) + 2. For PSR=0 antibodies, use AC-SINS descending (tiebreaker) + 3. 
Remove top 30 + + Args: + df: Dataset with 89 specific + 27 non-specific = 116 total + + Returns: + Final 86-antibody dataset (59 spec + 27 nonspec) + """ + # Get remaining specific and non-specific antibodies + specific = df[df["label"] == 0].copy() + nonspecific = df[df["label"] == 1].copy() + + # Sort by PSR (descending), then AC-SINS (descending), then id (alphabetical) + specific_sorted = specific.sort_values( + by=["psr", "ac_sins", "id"], ascending=[False, False, True] + ) + + # Keep bottom 59 specific + all 27 non-specific + specific_keep = specific_sorted.tail(NOVO_PARITY_SPECIFIC_COUNT) + df_86 = pd.concat([specific_keep, nonspecific], ignore_index=True) + + # Sort by id for consistency + df_86 = df_86.sort_values("id").reset_index(drop=True) + + spec_count = (df_86["label"] == 0).sum() + nonspec_count = (df_86["label"] == 1).sum() + + self.logger.info("\nRemoved 30 specific by PSR/AC-SINS:") + self.logger.info(f" Final: {len(df_86)} antibodies") + self.logger.info( + f" Specific: {spec_count} (expected {NOVO_PARITY_SPECIFIC_COUNT})" + ) + self.logger.info( + f" Non-specific: {nonspec_count} (expected {NOVO_PARITY_NONSPECIFIC_COUNT})" + ) + + return df_86 + + +# ========== CONVENIENCE FUNCTIONS FOR LOADING DATA ========== + + +def load_jain_data( + full_csv: str | None = None, + sd03_csv: str | None = None, + stage: str = "parity", +) -> pd.DataFrame: + """ + Convenience function to load preprocessed Jain dataset. + + IMPORTANT: This loads PREPROCESSED data. 
To preprocess raw data, use: + preprocessing/jain/step2_preprocess_p5e_s2.py + + Args: + full_csv: Path to jain_with_private_elisa_FULL.csv + sd03_csv: Path to jain_sd03.csv (biophysical data) + stage: Processing stage ("full", "ssot", or "parity") + + Returns: + DataFrame with preprocessed data + + Example: + >>> from antibody_training_esm.datasets.jain import load_jain_data + >>> df = load_jain_data(stage="parity") # 86 antibodies (Novo parity) + >>> print(f"Loaded {len(df)} sequences") + """ + dataset = JainDataset() + return dataset.load_data( + full_csv_path=full_csv, sd03_csv_path=sd03_csv, stage=stage + ) diff --git a/src/antibody_training_esm/datasets/mixins/__init__.py b/src/antibody_training_esm/datasets/mixins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df06c8899067b3b05cab69ae8f784265e38ca713 --- /dev/null +++ b/src/antibody_training_esm/datasets/mixins/__init__.py @@ -0,0 +1,2 @@ +from .annotation_mixin import AnnotationMixin +from .fragment_mixin import FragmentMixin diff --git a/src/antibody_training_esm/datasets/mixins/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/datasets/mixins/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fdb393673ffc23622ce0b2d22f4dcafc7d8ce91 Binary files /dev/null and b/src/antibody_training_esm/datasets/mixins/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/mixins/__pycache__/annotation_mixin.cpython-312.pyc b/src/antibody_training_esm/datasets/mixins/__pycache__/annotation_mixin.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..841ff7453c13802e66ef3a3697b5496b58a1ac7b Binary files /dev/null and b/src/antibody_training_esm/datasets/mixins/__pycache__/annotation_mixin.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/mixins/__pycache__/fragment_mixin.cpython-312.pyc 
b/src/antibody_training_esm/datasets/mixins/__pycache__/fragment_mixin.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..888bef98e09977f2cb682e08591e81289933d827 Binary files /dev/null and b/src/antibody_training_esm/datasets/mixins/__pycache__/fragment_mixin.cpython-312.pyc differ diff --git a/src/antibody_training_esm/datasets/mixins/annotation_mixin.py b/src/antibody_training_esm/datasets/mixins/annotation_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9d4e7775da6c3658a6d2a78f879a523c1eedb6 --- /dev/null +++ b/src/antibody_training_esm/datasets/mixins/annotation_mixin.py @@ -0,0 +1,120 @@ +""" +Mixin for antibody sequence annotation using ANARCI/riot_na. +""" + +import logging + +import pandas as pd + + +class AnnotationMixin: + """Mixin for ANARCI sequence annotation capabilities.""" + + logger: logging.Logger + + def annotate_sequence( + self, sequence_id: str, sequence: str, chain: str + ) -> dict[str, str] | None: + """ + Annotate a single sequence using ANARCI (IMGT numbering). + + This method wraps riot_na.create_riot_aa() to extract CDR/FWR regions. 
+ + Args: + sequence_id: Unique identifier for the sequence + sequence: Protein sequence to annotate + chain: Chain type ("H" for heavy, "L" for light) + + Returns: + Dictionary with keys: FWR1, CDR1, FWR2, CDR2, FWR3, CDR3, FWR4 + Returns None if annotation fails + """ + try: + # Import riot_na here to avoid dependency issues + from riot_na import create_riot_aa + + # Run ANARCI annotation + result = create_riot_aa(sequence_id, sequence, chain=chain) + + if result is None: + self.logger.warning( + f"ANARCI annotation failed for {sequence_id} ({chain} chain)" + ) + return None + + # Extract regions + annotations = { + "FWR1": result.get("FWR1", ""), + "CDR1": result.get("CDR1", ""), + "FWR2": result.get("FWR2", ""), + "CDR2": result.get("CDR2", ""), + "FWR3": result.get("FWR3", ""), + "CDR3": result.get("CDR3", ""), + "FWR4": result.get("FWR4", ""), + } + + # Validate annotations (should not be empty) + if not any(annotations.values()): + self.logger.warning( + f"All annotations empty for {sequence_id} ({chain} chain)" + ) + return None + + return annotations + + except Exception as e: + self.logger.error(f"Error annotating {sequence_id} ({chain} chain): {e}") + return None + + def annotate_all(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Annotate all sequences in a DataFrame. + + Adds annotation columns for heavy and light chains. 
+ + Args: + df: DataFrame with VH_sequence and optionally VL_sequence columns + + Returns: + DataFrame with annotation columns added + """ + self.logger.info(f"Annotating {len(df)} sequences...") + + # Annotate heavy chains + if "VH_sequence" in df.columns: + self.logger.info("Annotating VH sequences...") + vh_annotations: pd.Series = df.apply( + lambda row: self.annotate_sequence( + row.get("id", f"seq_{row.name}"), row["VH_sequence"], "H" + ) + if pd.notna(row["VH_sequence"]) + else None, + axis=1, + ) + + # Extract annotation fields + for field in ["FWR1", "CDR1", "FWR2", "CDR2", "FWR3", "CDR3", "FWR4"]: + df[f"VH_{field}"] = vh_annotations.apply( + lambda x, f=field: x.get(f, "") if x else "" + ) + + # Annotate light chains (if present) + if "VL_sequence" in df.columns: + self.logger.info("Annotating VL sequences...") + vl_annotations: pd.Series = df.apply( + lambda row: self.annotate_sequence( + row.get("id", f"seq_{row.name}"), row["VL_sequence"], "L" + ) + if pd.notna(row["VL_sequence"]) + else None, + axis=1, + ) + + # Extract annotation fields + for field in ["FWR1", "CDR1", "FWR2", "CDR2", "FWR3", "CDR3", "FWR4"]: + df[f"VL_{field}"] = vl_annotations.apply( + lambda x, f=field: x.get(f, "") if x else "" + ) + + self.logger.info("Annotation complete") + return df diff --git a/src/antibody_training_esm/datasets/mixins/fragment_mixin.py b/src/antibody_training_esm/datasets/mixins/fragment_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed339fc8e5e0a7f88ab364eb48ceddda26124f3 --- /dev/null +++ b/src/antibody_training_esm/datasets/mixins/fragment_mixin.py @@ -0,0 +1,249 @@ +""" +Mixin for fragment handling (statistics, CSV export). 
+""" + +import logging +from pathlib import Path +from typing import Any + +import pandas as pd + + +class FragmentMixin: + """Mixin for antibody fragment generation.""" + + logger: logging.Logger + output_dir: Path + dataset_name: str + + def get_fragment_types(self) -> list[str]: + """Expected to be implemented by the main class.""" + raise NotImplementedError + + def create_fragments(self, row: pd.Series) -> dict[str, tuple[str, int, str]]: + """ + Create all fragment types from an annotated sequence row. + + Args: + row: DataFrame row with annotation columns + + Returns: + Dictionary mapping fragment_type -> (sequence, label, source) + + Raises: + ValueError: If required annotation columns are missing + """ + fragments = {} + sequence_id = row.get("id", f"seq_{row.name}") + label = row.get("label", 0) + + fragment_types = self.get_fragment_types() + + # Validate that required columns exist for requested fragments + required_cols = set() + if ( + any( + ft in fragment_types + for ft in [ + "VH_only", + "VH+VL", + "H-CDR1", + "H-CDR2", + "H-CDR3", + "H-CDRs", + "H-FWRs", + ] + ) + and "VH_sequence" not in row + ): + required_cols.add("VH_sequence") + if ( + any( + ft in fragment_types + for ft in [ + "VL_only", + "VH+VL", + "L-CDR1", + "L-CDR2", + "L-CDR3", + "L-CDRs", + "L-FWRs", + ] + ) + and "VL_sequence" not in row + ): + required_cols.add("VL_sequence") + + if required_cols: + raise ValueError( + f"Missing required columns for fragment extraction: {sorted(required_cols)}. " + f"Available columns: {sorted(row.index.tolist())}. " + "Did annotation fail?" 
+ ) + + # Helper to concatenate regions + def concat(*regions: Any) -> str: + return "".join(str(r) for r in regions if pd.notna(r) and r != "") + + # Full antibody fragments + if "VH_only" in fragment_types: + fragments["VH_only"] = (row.get("VH_sequence", ""), label, sequence_id) + + if "VL_only" in fragment_types: + fragments["VL_only"] = (row.get("VL_sequence", ""), label, sequence_id) + + if "VH+VL" in fragment_types: + vh = row.get("VH_sequence", "") + vl = row.get("VL_sequence", "") + fragments["VH+VL"] = (concat(vh, vl), label, sequence_id) + + # Heavy chain fragments + if "H-CDR1" in fragment_types: + fragments["H-CDR1"] = (row.get("VH_CDR1", ""), label, sequence_id) + if "H-CDR2" in fragment_types: + fragments["H-CDR2"] = (row.get("VH_CDR2", ""), label, sequence_id) + if "H-CDR3" in fragment_types: + fragments["H-CDR3"] = (row.get("VH_CDR3", ""), label, sequence_id) + + if "H-CDRs" in fragment_types: + h_cdrs = concat( + row.get("VH_CDR1", ""), + row.get("VH_CDR2", ""), + row.get("VH_CDR3", ""), + ) + fragments["H-CDRs"] = (h_cdrs, label, sequence_id) + + if "H-FWRs" in fragment_types: + h_fwrs = concat( + row.get("VH_FWR1", ""), + row.get("VH_FWR2", ""), + row.get("VH_FWR3", ""), + row.get("VH_FWR4", ""), + ) + fragments["H-FWRs"] = (h_fwrs, label, sequence_id) + + # Light chain fragments + if "L-CDR1" in fragment_types: + fragments["L-CDR1"] = (row.get("VL_CDR1", ""), label, sequence_id) + if "L-CDR2" in fragment_types: + fragments["L-CDR2"] = (row.get("VL_CDR2", ""), label, sequence_id) + if "L-CDR3" in fragment_types: + fragments["L-CDR3"] = (row.get("VL_CDR3", ""), label, sequence_id) + + if "L-CDRs" in fragment_types: + l_cdrs = concat( + row.get("VL_CDR1", ""), + row.get("VL_CDR2", ""), + row.get("VL_CDR3", ""), + ) + fragments["L-CDRs"] = (l_cdrs, label, sequence_id) + + if "L-FWRs" in fragment_types: + l_fwrs = concat( + row.get("VL_FWR1", ""), + row.get("VL_FWR2", ""), + row.get("VL_FWR3", ""), + row.get("VL_FWR4", ""), + ) + fragments["L-FWRs"] 
= (l_fwrs, label, sequence_id) + + # Combined fragments + if "All-CDRs" in fragment_types: + all_cdrs = concat( + row.get("VH_CDR1", ""), + row.get("VH_CDR2", ""), + row.get("VH_CDR3", ""), + row.get("VL_CDR1", ""), + row.get("VL_CDR2", ""), + row.get("VL_CDR3", ""), + ) + fragments["All-CDRs"] = (all_cdrs, label, sequence_id) + + if "All-FWRs" in fragment_types: + all_fwrs = concat( + row.get("VH_FWR1", ""), + row.get("VH_FWR2", ""), + row.get("VH_FWR3", ""), + row.get("VH_FWR4", ""), + row.get("VL_FWR1", ""), + row.get("VL_FWR2", ""), + row.get("VL_FWR3", ""), + row.get("VL_FWR4", ""), + ) + fragments["All-FWRs"] = (all_fwrs, label, sequence_id) + + if "Full" in fragment_types: + full = concat( + row.get("VH_sequence", ""), + row.get("VL_sequence", ""), + ) + fragments["Full"] = (full, label, sequence_id) + + # Nanobody-specific (VHH) + if "VHH_only" in fragment_types: + fragments["VHH_only"] = (row.get("VH_sequence", ""), label, sequence_id) + + return fragments + + def create_fragment_csvs(self, df: pd.DataFrame, suffix: str = "") -> None: + """ + Generate CSV files for all fragment types. 
+ + Creates one CSV file per fragment type with columns: + - id: sequence identifier + - sequence: fragment sequence + - label: binary label (0=specific, 1=non-specific) + - source: original sequence ID + + Args: + df: Annotated DataFrame + suffix: Optional suffix for output filenames (e.g., "_filtered") + """ + self.logger.info("Generating fragment CSVs...") + + fragment_types = self.get_fragment_types() + + # Collect fragments for each type + fragment_data: dict[str, list[dict[str, Any]]] = { + ftype: [] for ftype in fragment_types + } + + for _, row in df.iterrows(): + fragments = self.create_fragments(row) + for ftype, (seq, label, source) in fragments.items(): + if seq: # Skip empty sequences + fragment_data[ftype].append( + { + "id": f"{source}_{ftype}", + "sequence": seq, + "label": label, + "source": source, + } + ) + + # Write CSV files + for ftype, data in fragment_data.items(): + if not data: + self.logger.warning(f"No data for fragment type: {ftype}") + continue + + output_file = self.output_dir / f"{ftype}_{self.dataset_name}{suffix}.csv" + fragment_df = pd.DataFrame(data) + + # Write with metadata header + with open(output_file, "w") as f: + f.write(f"# Dataset: {self.dataset_name}\n") + f.write(f"# Fragment type: {ftype}\n") + f.write(f"# Total sequences: {len(fragment_df)}\n") + f.write( + f"# Label distribution: " + f"{(fragment_df['label'] == 0).sum()} specific, " + f"{(fragment_df['label'] == 1).sum()} non-specific\n" + ) + fragment_df.to_csv(f, index=False) + + self.logger.info( + f" {ftype}: {len(fragment_df)} sequences → {output_file.name}" + ) + + self.logger.info(f"Fragment CSVs written to {self.output_dir}") diff --git a/src/antibody_training_esm/datasets/shehata.py b/src/antibody_training_esm/datasets/shehata.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd973fef351f38b6a5d74be38475da8f8843ab4 --- /dev/null +++ b/src/antibody_training_esm/datasets/shehata.py @@ -0,0 +1,283 @@ +""" +Shehata Dataset Loader + +Loads 
preprocessed Shehata HIV antibody polyreactivity dataset. + +IMPORTANT: This module is for LOADING preprocessed data, not for running +the preprocessing pipeline. The preprocessing scripts that CREATE the data +are in: preprocessing/shehata/step2_extract_fragments.py + +Dataset characteristics: +- Full antibodies (VH + VL) +- 398 HIV-specific antibodies from 8 donors +- Binary classification based on PSR (Polyreactivity Screening Reagent) scores +- B cell subset metadata (memory, naive, plasmablast) +- 16 fragment types (full antibody) + +Source: +- data/test/shehata/raw/shehata-mmc2.xlsx + +Reference: +- Shehata et al. (2019), "Affinity Maturation Enhances Antibody Specificity but Compromises Conformational Stability" + Supplementary Material mmc2.xlsx +""" + +import logging +from pathlib import Path +from typing import Any + +import pandas as pd +import pandera.pandas as pa + +from antibody_training_esm.schemas.dataset import get_shehata_schema +from antibody_training_esm.settings import settings + +from .base import AntibodyDataset + +SHEHATA_EXCEL_PATH = settings.SHEHATA_EXCEL_PATH +SHEHATA_OUTPUT_DIR = settings.SHEHATA_OUTPUT_DIR + + +class ShehataDataset(AntibodyDataset): + """ + Loader for Shehata HIV antibody dataset. + + This class provides an interface to LOAD preprocessed Shehata dataset files. + It does NOT run the preprocessing pipeline - use preprocessing/shehata/step2_extract_fragments.py for that. + + The Shehata dataset contains HIV-specific antibodies from 8 donors, with PSR scores + measuring polyreactivity. The paper reports 7/398 (1.76%) as non-specific, + corresponding to the 98.24th percentile threshold. + """ + + # Default PSR threshold (98.24th percentile based on paper: 7/398 non-specific) + DEFAULT_PSR_PERCENTILE = 0.9824 + + def __init__( + self, output_dir: Path | None = None, logger: logging.Logger | None = None + ): + """ + Initialize Shehata dataset loader. 
+ + Args: + output_dir: Directory containing preprocessed fragment files + logger: Logger instance + """ + super().__init__( + dataset_name="shehata", + output_dir=output_dir or SHEHATA_OUTPUT_DIR, + logger=logger, + ) + + @classmethod + def get_schema(cls) -> pa.DataFrameSchema: + return get_shehata_schema() + + def get_fragment_types(self) -> list[str]: + """ + Return full antibody fragment types. + + Shehata contains VH + VL sequences, so we generate all 16 fragment types. + + Returns: + List of 16 full antibody fragment types + """ + return self.FULL_ANTIBODY_FRAGMENTS + + def calculate_psr_threshold( + self, + psr_scores: pd.Series, + percentile: float | None = None, + ) -> float: + """ + Calculate PSR score threshold for binary classification. + + Based on paper: "7 out of 398 antibodies characterised as non-specific" + This is 1.76% = 98.24th percentile + + Args: + psr_scores: Series of PSR scores (numeric) + percentile: Percentile to use (default: 0.9824 for 7/398) + + Returns: + PSR threshold value + """ + if percentile is None: + percentile = self.DEFAULT_PSR_PERCENTILE + + threshold = psr_scores.quantile(percentile) + + self.logger.info("\nPSR Score Analysis:") + self.logger.info(f" Valid PSR scores: {psr_scores.notna().sum()}") + self.logger.info(f" Mean: {psr_scores.mean():.4f}") + self.logger.info(f" Median: {psr_scores.median():.4f}") + self.logger.info(f" 75th percentile: {psr_scores.quantile(0.75):.4f}") + self.logger.info(f" 95th percentile: {psr_scores.quantile(0.95):.4f}") + self.logger.info(f" Max: {psr_scores.max():.4f}") + self.logger.info(f"\n PSR = 0: {(psr_scores == 0).sum()} antibodies") + self.logger.info(f" PSR > 0: {(psr_scores > 0).sum()} antibodies") + self.logger.info( + "\n Paper reports: 7/398 non-specific (~1.76%, 98.24th percentile)" + ) + self.logger.info(f" Calculated threshold: {threshold:.4f}") + + return threshold + + def load_data( + self, + excel_path: str | Path | None = None, + psr_threshold: float | None = None, + **_: 
Any, + ) -> pd.DataFrame: + """ + Load Shehata dataset from Excel file. + + Args: + excel_path: Path to shehata-mmc2.xlsx + psr_threshold: PSR score threshold for binary classification. + If None, calculates 98.24th percentile automatically. + + Returns: + DataFrame with columns: id, VH_sequence, VL_sequence, label, psr_measurement, b_cell_subset + + Raises: + FileNotFoundError: If Excel file not found + """ + # Default path + if excel_path is None: + excel_path = SHEHATA_EXCEL_PATH + + excel_file = Path(excel_path) + if not excel_file.exists(): + raise FileNotFoundError( + f"Shehata Excel file not found: {excel_file}\n" + f"Please ensure mmc2.xlsx is in data/test/shehata/raw/" + ) + + # Load Excel + self.logger.info(f"Reading Excel file: {excel_file}") + df = pd.read_excel(excel_file) + + # Validate dataset is not empty + if len(df) == 0: + raise ValueError( + f"Loaded dataset is empty: {excel_file}\n" + "The Excel file may be corrupted or truncated. " + "Please check the file or re-run preprocessing." 
+ ) + + self.logger.info(f" Loaded {len(df)} rows, {len(df.columns)} columns") + + # Sanitize sequences (remove IMGT gap characters) + self.logger.info("Sanitizing sequences (removing gaps)...") + vh_original = df["VH Protein"].copy() + vl_original = df["VL Protein"].copy() + + df["VH Protein"] = df["VH Protein"].apply( + lambda x: self.sanitize_sequence(x) if pd.notna(x) else x + ) + df["VL Protein"] = df["VL Protein"].apply( + lambda x: self.sanitize_sequence(x) if pd.notna(x) else x + ) + + # Count gaps removed + gaps_vh = sum(str(s).count("-") if pd.notna(s) else 0 for s in vh_original) + gaps_vl = sum(str(s).count("-") if pd.notna(s) else 0 for s in vl_original) + + if gaps_vh > 0 or gaps_vl > 0: + self.logger.info(f" Removed {gaps_vh} gap characters from VH sequences") + self.logger.info(f" Removed {gaps_vl} gap characters from VL sequences") + + # Drop rows without sequences (Excel metadata/footnotes) + before_drop = len(df) + df = df.dropna(subset=["VH Protein", "VL Protein"], how="all") + dropped = before_drop - len(df) + if dropped: + self.logger.info( + f" Dropped {dropped} rows without VH/VL sequences (metadata)" + ) + + # Convert PSR scores to numeric + psr_numeric = pd.to_numeric(df["PSR Score"], errors="coerce") + invalid_psr_mask = psr_numeric.isna() + + if invalid_psr_mask.any(): + dropped_ids = df.loc[invalid_psr_mask, "Clone name"].tolist() + self.logger.warning( + f" Dropping {invalid_psr_mask.sum()} antibodies without numeric PSR scores: " + f"{', '.join(dropped_ids)}" + ) + df = df.loc[~invalid_psr_mask].reset_index(drop=True) + psr_numeric = psr_numeric.loc[~invalid_psr_mask].reset_index(drop=True) + + # Calculate PSR threshold if not provided + if psr_threshold is None: + psr_threshold = self.calculate_psr_threshold(psr_numeric) + else: + self.logger.info(f"Using provided PSR threshold: {psr_threshold}") + + # Create standardized DataFrame + df_output = pd.DataFrame( + { + "id": df["Clone name"], + "VH_sequence": df["VH Protein"], + 
"VL_sequence": df["VL Protein"], + "label": (psr_numeric > psr_threshold).astype(int), + "psr_measurement": psr_numeric, # Renamed from psr_score to match schema + "b_cell_subset": df["B cell subset"], + } + ) + + # Create 'sequence' column for schema validation (use VH) + if "sequence" not in df_output.columns and "VH_sequence" in df_output.columns: + df_output["sequence"] = df_output["VH_sequence"] + + # Validate with Pandera + df_output = self.validate_dataframe(df_output) + + # Label distribution + self.logger.info("\nLabel distribution:") + label_counts = df_output["label"].value_counts().sort_index() + for label, count in label_counts.items(): + label_name = "Specific" if label == 0 else "Non-specific" + percentage = (count / len(df_output)) * 100 + self.logger.info( + f" {label_name} (label={label}): {count} ({percentage:.1f}%)" + ) + + # B cell subset distribution + self.logger.info("\nB cell subset distribution:") + subset_counts = df_output["b_cell_subset"].value_counts() + for subset, count in subset_counts.items(): + self.logger.info(f" {subset}: {count}") + + return df_output + + +# ========== CONVENIENCE FUNCTIONS FOR LOADING DATA ========== + + +def load_shehata_data( + excel_path: str | None = None, + psr_threshold: float | None = None, +) -> pd.DataFrame: + """ + Convenience function to load preprocessed Shehata dataset. + + IMPORTANT: This loads PREPROCESSED data. 
To preprocess raw data, use: + preprocessing/shehata/step2_extract_fragments.py + + Args: + excel_path: Path to shehata-mmc2.xlsx + psr_threshold: PSR threshold for classification (None = auto-calculate) + + Returns: + DataFrame with preprocessed data + + Example: + >>> from antibody_training_esm.datasets.shehata import load_shehata_data + >>> df = load_shehata_data() + >>> print(f"Loaded {len(df)} sequences") + """ + dataset = ShehataDataset() + return dataset.load_data(excel_path=excel_path, psr_threshold=psr_threshold) diff --git a/src/antibody_training_esm/evaluation/__init__.py b/src/antibody_training_esm/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/antibody_training_esm/models/__init__.py b/src/antibody_training_esm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37c3dd9f4fa43190e322eb549abf26391ed3e10a --- /dev/null +++ b/src/antibody_training_esm/models/__init__.py @@ -0,0 +1,37 @@ +""" +Pydantic models for runtime validation. 
+ +This package contains schema definitions for: +- Prediction requests/responses +- Configuration validation (Phase 2) +- Dataset schemas (Phase 3) +- Model artifacts (Phase 4) +""" + +from antibody_training_esm.models.config import ( + ClassifierConfig, + DataConfig, + ExperimentConfig, + ModelConfig, + TrainingConfig, + TrainingPipelineConfig, +) +from antibody_training_esm.models.prediction import ( + BatchPredictionRequest, + PredictionRequest, + PredictionResult, +) + +__all__ = [ + # Prediction models + "PredictionRequest", + "BatchPredictionRequest", + "PredictionResult", + # Config models + "ModelConfig", + "DataConfig", + "ClassifierConfig", + "TrainingConfig", + "ExperimentConfig", + "TrainingPipelineConfig", +] diff --git a/src/antibody_training_esm/models/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c23a1811778d2ee41ecc943fe9d3412917bc5b66 Binary files /dev/null and b/src/antibody_training_esm/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/models/__pycache__/artifact.cpython-312.pyc b/src/antibody_training_esm/models/__pycache__/artifact.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd98bb8baaf54e36794ce36c3872ceeaa0878ac2 Binary files /dev/null and b/src/antibody_training_esm/models/__pycache__/artifact.cpython-312.pyc differ diff --git a/src/antibody_training_esm/models/__pycache__/config.cpython-312.pyc b/src/antibody_training_esm/models/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28dd14f68bed0870b713902d7e3268f66bdd6930 Binary files /dev/null and b/src/antibody_training_esm/models/__pycache__/config.cpython-312.pyc differ diff --git a/src/antibody_training_esm/models/__pycache__/prediction.cpython-312.pyc b/src/antibody_training_esm/models/__pycache__/prediction.cpython-312.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..36d576b12f6da903e80ae1643b8bcae6e92ac4a7 Binary files /dev/null and b/src/antibody_training_esm/models/__pycache__/prediction.cpython-312.pyc differ diff --git a/src/antibody_training_esm/models/artifact.py b/src/antibody_training_esm/models/artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..23cfb757a67e2159363b2f6c2550e06ec5bc2c4d --- /dev/null +++ b/src/antibody_training_esm/models/artifact.py @@ -0,0 +1,423 @@ +""" +Pydantic models for model artifacts and metrics. + +This module defines the schema for: +1. Saved model metadata (JSON sidecar) +2. Evaluation metrics (accuracy, F1, etc.) +3. Cross-validation results +""" + +from typing import Any, Literal, cast + +import numpy as np +from pydantic import BaseModel, Field + + +class ModelArtifactMetadata(BaseModel): + """ + Metadata for saved model artifacts. + + This model structures the JSON sidecar file that accompanies + NPZ/XGB model files. It enables version checking and parameter + reconstruction. 
+ """ + + # Model architecture + model_name: str = Field( + ..., + description="HuggingFace ESM model ID", + examples=["facebook/esm1v_t33_650M_UR90S_1"], + ) + + model_type: Literal["logistic_regression", "xgboost", "random_forest"] = Field( + ..., + description="Classifier type", + ) + + sklearn_version: str = Field( + ..., + description="scikit-learn version used for training", + examples=["1.3.0"], + ) + + # Classifier configuration (strategy-specific) + classifier: dict[str, Any] = Field( + ..., + description="Full classifier config from to_dict() method", + ) + + # ESM embedding extractor params + esm_model: str = Field( + ..., + description="ESM model name (redundant with model_name, kept for compat)", + ) + + esm_revision: str = Field( + default="main", + description="HuggingFace model revision (commit hash)", + ) + + batch_size: int = Field( + default=16, + ge=1, + description="Batch size for embedding extraction", + ) + + device: str = Field( + default="cpu", + description="Device used during training", + ) + + # Legacy flat fields (LogReg only, for backward compatibility) + C: float | None = Field( + default=None, + description="LogReg: Inverse regularization strength", + ) + + penalty: Literal["l1", "l2"] | None = Field( + default=None, + description="LogReg: Regularization type", + ) + + solver: str | None = Field( + default=None, + description="LogReg: Optimization algorithm", + ) + + # Pydantic handles dict[int, float] keys automatically (converts string keys back to int) + class_weight: Literal["balanced"] | dict[int, float] | None = Field( + default=None, + description="Class weighting strategy", + ) + + max_iter: int | None = Field( + default=None, + description="LogReg: Maximum iterations", + ) + + random_state: int | None = Field( + default=None, + description="Random seed", + ) + + # Optional metrics from training + training_metrics: dict[str, float] | None = Field( + default=None, + description="Metrics from final training run", + ) + + 
@classmethod + def from_classifier(cls, classifier: Any) -> "ModelArtifactMetadata": + """ + Construct metadata from BinaryClassifier instance. + + Args: + classifier: Trained BinaryClassifier + + Returns: + ModelArtifactMetadata + """ + import sklearn + + strategy_config = classifier.classifier.to_dict() + classifier_type = strategy_config.get("type", "logistic_regression") + + metadata_dict = { + # Model architecture + "model_name": classifier.model_name, + "model_type": classifier_type, + "sklearn_version": sklearn.__version__, + # Classifier config (strategy-specific) + "classifier": strategy_config, + # ESM params + "esm_model": classifier.model_name, + "esm_revision": classifier.revision, + "batch_size": classifier.batch_size, + "device": classifier.device, + } + + # Add legacy flat fields for LogReg (backward compat) + if classifier_type == "logistic_regression": + metadata_dict.update( + { + "C": classifier.C, + "penalty": classifier.penalty, + "solver": classifier.solver, + "class_weight": classifier.class_weight, + "max_iter": classifier.max_iter, + "random_state": classifier.random_state, + } + ) + + return cls.model_validate(metadata_dict) + + def to_classifier_params(self) -> dict[str, Any]: + """ + Extract parameters for BinaryClassifier reconstruction. + + Returns: + Dict of parameters for BinaryClassifier(...) init + """ + params = { + # ESM params + "model_name": self.esm_model, + "device": self.device, + "batch_size": self.batch_size, + "revision": self.esm_revision, + # Classifier params + **self.classifier, + } + + # Overwrite with typed fields for LogReg to ensure correct types (e.g. 
int keys in dict) + if self.model_type == "logistic_regression": + params.update( + { + "C": self.C, + "penalty": self.penalty, + "solver": self.solver, + "class_weight": self.class_weight, + "max_iter": self.max_iter, + "random_state": self.random_state, + } + ) + + return params + + +class EvaluationMetrics(BaseModel): + """ + Evaluation metrics for a single dataset. + + Used for training set, test set, and cross-validation fold results. + """ + + accuracy: float = Field( + ..., + ge=0.0, + le=1.0, + description="Classification accuracy (0-1)", + ) + + precision: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="Precision (positive predictive value)", + ) + + recall: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="Recall (sensitivity, true positive rate)", + ) + + f1: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="F1 score (harmonic mean of precision and recall)", + ) + + roc_auc: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="Area under ROC curve", + ) + + # Optional confusion matrix + confusion_matrix: list[list[int]] | None = Field( + default=None, + description="Confusion matrix [[TN, FP], [FN, TP]]", + ) + + # Dataset metadata + dataset_name: str | None = Field( + default=None, + description="Name of evaluated dataset (e.g., 'Jain', 'Training')", + ) + + n_samples: int | None = Field( + default=None, + ge=0, + description="Number of samples in dataset", + ) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "accuracy": 0.6628, + "precision": 0.47, + "recall": 0.63, + "f1": 0.54, + "roc_auc": 0.68, + "confusion_matrix": [[40, 19], [10, 17]], + "dataset_name": "Jain", + "n_samples": 86, + } + ] + } + } + + @classmethod + def from_sklearn_metrics( + cls, + y_true: np.ndarray, + y_pred: np.ndarray, + y_proba: np.ndarray | None = None, + dataset_name: str | None = None, + ) -> "EvaluationMetrics": + """ + Construct metrics from sklearn predictions. 
+ + Args: + y_true: Ground truth labels + y_pred: Predicted labels + y_proba: Predicted probabilities (for ROC-AUC) + dataset_name: Name of dataset + + Returns: + EvaluationMetrics + """ + from sklearn.metrics import ( + accuracy_score, + confusion_matrix, + f1_score, + precision_score, + recall_score, + roc_auc_score, + ) + + metrics_dict = { + "accuracy": float(accuracy_score(y_true, y_pred)), + "precision": float(precision_score(y_true, y_pred, zero_division=0)), + "recall": float(recall_score(y_true, y_pred, zero_division=0)), + "f1": float(f1_score(y_true, y_pred, zero_division=0)), + "dataset_name": dataset_name, + "n_samples": len(y_true), + "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(), + } + + # ROC-AUC requires probabilities + if y_proba is not None: + try: + # Check if y_proba has 2 columns (binary classification) + if y_proba.ndim == 2 and y_proba.shape[1] >= 2: + score = roc_auc_score(y_true, y_proba[:, 1]) + else: + # Fallback for 1D array if passed incorrectly + score = roc_auc_score(y_true, y_proba) + metrics_dict["roc_auc"] = float(score) + except ValueError: + # ROC AUC might fail if only one class is present in y_true + metrics_dict["roc_auc"] = None + + return cls.model_validate(metrics_dict) + + +class CVResults(BaseModel): + """ + Cross-validation results with mean and std for each metric. + + Aggregates metrics across all CV folds. 
+ """ + + cv_accuracy: dict[Literal["mean", "std"], float] = Field( + ..., + description="Mean and std of accuracy across folds", + ) + + cv_precision: dict[Literal["mean", "std"], float] | None = Field( + default=None, + description="Mean and std of precision", + ) + + cv_recall: dict[Literal["mean", "std"], float] | None = Field( + default=None, + description="Mean and std of recall", + ) + + cv_f1: dict[Literal["mean", "std"], float] | None = Field( + default=None, + description="Mean and std of F1 score", + ) + + cv_roc_auc: dict[Literal["mean", "std"], float] | None = Field( + default=None, + description="Mean and std of ROC-AUC", + ) + + n_splits: int = Field( + ..., + ge=2, + description="Number of cross-validation folds", + ) + + # Optional: per-fold results + fold_results: list[EvaluationMetrics] | None = Field( + default=None, + description="Metrics for each individual fold", + ) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "cv_accuracy": {"mean": 0.82, "std": 0.05}, + "cv_precision": {"mean": 0.78, "std": 0.06}, + "cv_recall": {"mean": 0.85, "std": 0.04}, + "cv_f1": {"mean": 0.81, "std": 0.05}, + "cv_roc_auc": {"mean": 0.87, "std": 0.03}, + "n_splits": 10, + } + ] + } + } + + @classmethod + def from_sklearn_cv_results( + cls, + cv_scores: dict[str, list[float] | np.ndarray], + n_splits: int, + ) -> "CVResults": + """ + Construct CVResults from sklearn cross_validate output. 
+ + Args: + cv_scores: Dict like {"test_accuracy": [...], "test_f1": [...]} + n_splits: Number of folds + + Returns: + CVResults + """ + results_dict: dict[str, Any] = {"n_splits": n_splits} + + # Map sklearn metric names to our field names + metric_map = { + "test_accuracy": "cv_accuracy", + "test_precision": "cv_precision", + "test_recall": "cv_recall", + "test_f1": "cv_f1", + "test_roc_auc": "cv_roc_auc", + } + + for sklearn_name, pydantic_name in metric_map.items(): + if sklearn_name in cv_scores: + scores = cv_scores[sklearn_name] + # Handle potential NaN in scores + valid_scores = [s for s in scores if not np.isnan(s)] + + if valid_scores: + results_dict[pydantic_name] = { + "mean": float(np.mean(valid_scores)), + "std": float(np.std(valid_scores)), + } + else: + results_dict[pydantic_name] = { + "mean": 0.0, + "std": 0.0, + } + + return cast(CVResults, cls.model_validate(results_dict)) diff --git a/src/antibody_training_esm/models/config.py b/src/antibody_training_esm/models/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6962b16c7a43a5b6ea41b29a6b7489d76f321a79 --- /dev/null +++ b/src/antibody_training_esm/models/config.py @@ -0,0 +1,305 @@ +from pathlib import Path +from typing import Any, Literal + +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel, Field, field_validator + + +class ModelConfig(BaseModel): + """ + ESM protein language model configuration. + + Controls which HuggingFace model to load and execution device. 
+ """ + + name: str = Field( + ..., + description="HuggingFace model ID (e.g., facebook/esm1v_t33_650M_UR90S_1)", + examples=["facebook/esm1v_t33_650M_UR90S_1", "facebook/esm2_t33_650M_UR50D"], + ) + + device: Literal["cpu", "cuda", "mps", "auto"] = Field( + default="auto", + description="Execution device (auto = CUDA > MPS > CPU)", + ) + + revision: str = Field( + default="main", + description="HuggingFace model revision (commit hash for reproducibility)", + ) + + batch_size: int = Field( + default=8, + ge=1, + le=128, + description="Batch size for embedding extraction", + ) + + +class DataConfig(BaseModel): + """ + Dataset configuration. + + Specifies input files and caching directories. + """ + + train_file: Path = Field( + ..., + description="Path to training CSV (must contain 'sequence' and 'label' columns)", + ) + + test_file: Path = Field( + ..., + description="Path to test CSV", + ) + + embeddings_cache_dir: Path = Field( + default=Path("experiments/cache"), + description="Directory for cached ESM embeddings", + ) + + @field_validator("train_file", "test_file") + @classmethod + def validate_file_exists(cls, v: Path) -> Path: + """Ensure file exists at config load time.""" + if not v.exists(): + raise FileNotFoundError(f"Data file not found: {v}") + return v + + @field_validator("embeddings_cache_dir") + @classmethod + def create_cache_dir(cls, v: Path) -> Path: + """Create cache directory if it doesn't exist.""" + v.mkdir(parents=True, exist_ok=True) + return v + + +class ClassifierConfig(BaseModel): + """ + Classifier configuration (strategy-agnostic). + + Supports both LogisticRegression and XGBoost strategies. 
+ """ + + strategy: Literal["logistic_regression", "xgboost"] = Field( + default="logistic_regression", + description="Classification strategy", + ) + + # LogisticRegression params (ignored if strategy=xgboost) + C: float | None = Field( + default=1.0, + gt=0.0, + description="Inverse regularization strength (LogReg only)", + ) + + penalty: Literal["l1", "l2"] | None = Field( + default="l2", + description="Regularization type (LogReg only)", + ) + + solver: Literal["lbfgs", "liblinear", "saga"] | None = Field( + default="lbfgs", + description="Optimization algorithm (LogReg only)", + ) + + class_weight: Literal["balanced"] | dict[int, float] | None = Field( + default="balanced", + description="Class weighting strategy", + ) + + max_iter: int | None = Field( + default=1000, + ge=100, + description="Maximum optimization iterations", + ) + + random_state: int | None = Field( + default=42, + description="Random seed for reproducibility", + ) + + # XGBoost params (ignored if strategy=logistic_regression) + n_estimators: int | None = Field( + default=100, + ge=1, + description="Number of boosting rounds (XGBoost only)", + ) + + max_depth: int | None = Field( + default=6, + ge=1, + le=20, + description="Maximum tree depth (XGBoost only)", + ) + + learning_rate: float | None = Field( + default=0.3, + gt=0.0, + le=1.0, + description="Learning rate (XGBoost only)", + ) + + +class TrainingConfig(BaseModel): + """ + Training orchestration configuration. + + Controls cross-validation, logging, and model persistence. 
+ """ + + n_splits: int = Field( + default=10, + ge=2, + le=20, + description="Number of cross-validation folds", + ) + + random_state: int = Field( + default=42, + description="Random seed used for cross-validation splits", + ) + + stratify: bool = Field( + default=True, + description="Whether to use stratified folds during cross-validation", + ) + + metrics: set[Literal["accuracy", "precision", "recall", "f1", "roc_auc"]] = Field( + default={"accuracy", "precision", "recall", "f1", "roc_auc"}, + description="Metrics to compute during evaluation", + ) + + save_model: bool = Field( + default=True, + description="Whether to save trained model", + ) + + model_save_dir: Path = Field( + default=Path("experiments/checkpoints"), + description="Base directory for saved models", + ) + + model_name: str = Field( + ..., + min_length=1, + description="Name for saved model file (e.g., boughter_vh_esm1v_logreg)", + ) + + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field( + default="INFO", + description="Logging verbosity", + ) + + log_file: str = Field( + default="training.log", + description="Log file name (relative to Hydra output dir)", + ) + + batch_size: int = Field( + default=8, + ge=1, + le=128, + description="Batch size for embedding extraction", + ) + + num_workers: int = Field( + default=4, + ge=0, + description="Number of workers for data loading or preprocessing", + ) + + @field_validator("model_save_dir") + @classmethod + def create_model_dir(cls, v: Path) -> Path: + """Create model save directory if needed.""" + if not v.exists(): + v.mkdir(parents=True, exist_ok=True) + return v + + +class ExperimentConfig(BaseModel): + """ + Experiment tracking metadata. + + Used for organizing Hydra outputs and logging. 
+ """ + + name: str = Field( + ..., + min_length=1, + description="Experiment name (used in Hydra output directory)", + ) + + tags: list[str] = Field( + default_factory=list, + description="Experiment tags for filtering/search", + ) + + description: str | None = Field( + default=None, + description="Human-readable experiment description", + ) + + +class TrainingPipelineConfig(BaseModel): + """ + Root configuration for training pipeline. + + Mirrors Hydra's config.yaml structure. + """ + + model: ModelConfig + data: DataConfig + classifier: ClassifierConfig + training: TrainingConfig + experiment: ExperimentConfig + + # Optional hardware config (added in config.yaml) + hardware: dict[str, Any] | None = Field( + default=None, + description="Hardware-specific overrides (device, num_threads)", + ) + + # Runtime metrics (attached after training) + train_metrics: dict[str, Any] | None = Field( + default=None, + description="Metrics from final training run (attached at runtime)", + exclude=True, # Do not expect this in input config + ) + + model_config = { + "json_schema_extra": { + "title": "Antibody Training Pipeline Configuration", + "description": "Complete configuration for ESM-based antibody training", + } + } + + @classmethod + def from_hydra(cls, cfg: DictConfig) -> "TrainingPipelineConfig": + """ + Convert Hydra DictConfig to Pydantic model. + + This is the main entry point for validation. 
+ """ + # Resolve all interpolations first + OmegaConf.resolve(cfg) + + # Convert to dict (Pydantic doesn't accept DictConfig directly) + config_dict = OmegaConf.to_container(cfg, resolve=True) + + # Backwards compatibility: allow training.batch_size overrides to populate model.batch_size + if isinstance(config_dict, dict): + model_cfg = config_dict.get("model", {}) or {} + training_cfg = config_dict.get("training", {}) or {} + if ( + "batch_size" not in model_cfg + and isinstance(training_cfg, dict) + and "batch_size" in training_cfg + ): + model_cfg["batch_size"] = training_cfg["batch_size"] + config_dict["model"] = model_cfg + + # Validate with Pydantic + result: TrainingPipelineConfig = cls.model_validate(config_dict) + return result diff --git a/src/antibody_training_esm/models/prediction.py b/src/antibody_training_esm/models/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..157df53dd3a067382b75b73140139b5cb0a19ac5 --- /dev/null +++ b/src/antibody_training_esm/models/prediction.py @@ -0,0 +1,151 @@ +from typing import Literal + +from pydantic import BaseModel, Field, field_validator + +AssayType = Literal["ELISA", "PSR"] + + +class PredictionRequest(BaseModel): + """ + Single sequence prediction request. + + Validates amino acid sequence and optional parameters. 
+ """ + + sequence: str = Field( + ..., + min_length=1, + max_length=2000, + description="Antibody amino acid sequence (VH or VL)", + examples=["QVQLVQSGAEVKKPGASVKVSCKASGYTFT..."], + ) + + threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Classification threshold (0-1)", + ) + + assay_type: AssayType | None = Field( + default=None, + description="Assay type for calibrated thresholds", + ) + + @field_validator("sequence") + @classmethod + def validate_amino_acids(cls, v: str) -> str: + """Validate sequence contains only valid amino acids.""" + # Clean whitespace + cleaned = v.strip().upper() + + if not cleaned: + raise ValueError("Sequence cannot be empty after cleaning") + + # Standard 20 amino acids + X (unknown) + valid_chars = set("ACDEFGHIKLMNPQRSTVWYX") + invalid_chars = set(cleaned) - valid_chars + + if invalid_chars: + raise ValueError( + f"Invalid characters found: {', '.join(sorted(invalid_chars))}. " + f"Only standard amino acids (ACDEFGHIKLMNPQRSTVWY) and X are allowed." + ) + + return cleaned + + model_config = { + "json_schema_extra": { + "examples": [ + { + "sequence": "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMG", + "threshold": 0.5, + "assay_type": "ELISA", + } + ] + } + } + + +class BatchPredictionRequest(BaseModel): + """ + Batch prediction request for multiple sequences. + + Supports both inline lists and file uploads (future). 
+ """ + + sequences: list[str] = Field( + ..., + min_length=1, + max_length=1000, # Batch size limit + description="List of antibody sequences", + ) + + threshold: float = Field(default=0.5, ge=0.0, le=1.0) + assay_type: AssayType | None = None + + @field_validator("sequences") + @classmethod + def validate_all_sequences(cls, v: list[str]) -> list[str]: + """Validate each sequence in batch.""" + cleaned = [] + errors = [] + + for i, seq in enumerate(v): + try: + # Reuse PredictionRequest validator + request = PredictionRequest(sequence=seq) + cleaned.append(request.sequence) + except ValueError as e: + errors.append(f"Sequence {i + 1}: {e}") + + if errors: + raise ValueError("Batch validation failed:\n" + "\n".join(errors)) + + return cleaned + + +class PredictionResult(BaseModel): + """ + Prediction result for a single sequence. + + Standardizes output format across CLI, Gradio, and future APIs. + """ + + sequence: str = Field(..., description="Input sequence (cleaned)") + + prediction: Literal["specific", "non-specific"] = Field( + ..., + description="Classification result", + ) + + probability: float = Field( + ..., + ge=0.0, + le=1.0, + description="Probability of non-specificity (class 1)", + ) + + threshold: float = Field( + ..., + description="Threshold used for classification", + ) + + assay_type: AssayType | None = Field( + default=None, + description="Assay type if calibrated threshold was used", + ) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "sequence": "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMH...", + "prediction": "specific", + "probability": 0.23, + "threshold": 0.5, + "assay_type": "ELISA", + } + ] + } + } diff --git a/src/antibody_training_esm/schemas/__init__.py b/src/antibody_training_esm/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2241ae856958ceb1ae2f4a366da387976a0048d --- /dev/null +++ b/src/antibody_training_esm/schemas/__init__.py @@ -0,0 +1,24 @@ +""" +Pandera schemas for 
DataFrame validation. + +This package contains schema definitions for: +- Base sequence datasets +- Training datasets (Boughter) +- Test datasets (Jain, Harvey, Shehata) +""" + +from antibody_training_esm.schemas.dataset import ( + get_boughter_schema, + get_harvey_schema, + get_jain_schema, + get_sequence_dataset_schema, + get_shehata_schema, +) + +__all__ = [ + "get_sequence_dataset_schema", + "get_boughter_schema", + "get_jain_schema", + "get_harvey_schema", + "get_shehata_schema", +] diff --git a/src/antibody_training_esm/schemas/__pycache__/__init__.cpython-312.pyc b/src/antibody_training_esm/schemas/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d9044073e1fc54d7af8408fc175c1ab0cc5d0b3 Binary files /dev/null and b/src/antibody_training_esm/schemas/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/antibody_training_esm/schemas/__pycache__/dataset.cpython-312.pyc b/src/antibody_training_esm/schemas/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2ae7a8502bff37a714d54fe7437f0d497c82693 Binary files /dev/null and b/src/antibody_training_esm/schemas/__pycache__/dataset.cpython-312.pyc differ diff --git a/src/antibody_training_esm/schemas/dataset.py b/src/antibody_training_esm/schemas/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..696cf0208da3313e80031942fe2fe8d6a58309cb --- /dev/null +++ b/src/antibody_training_esm/schemas/dataset.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import re + +import pandas as pd +import pandera.backends.pandas # noqa: F401 # registers pandas backend +import pandera.pandas as pa + +VALID_AA = set("ACDEFGHIKLMNPQRSTVWYX") +_UPPERCASE_PATTERN = re.compile(r"^[A-Z]+$") +_NO_GAP_PATTERN = re.compile(r"^[^*.-]+$") + + +def _regex_check(pattern: re.Pattern[str], name: str) -> pa.Check: + return pa.Check( + lambda series: 
bool(series.str.match(pattern).fillna(False).all()), + name=name, + ) + + +def _length_check(min_value: int, max_value: int, name: str) -> pa.Check: + return pa.Check( + lambda series: bool(series.str.len().between(min_value, max_value).all()), + name=name, + ) + + +def _amino_acid_check(series: pd.Series) -> bool: + return bool(series.dropna().map(lambda seq: set(str(seq)).issubset(VALID_AA)).all()) + + +def _no_gap_check(series: pd.Series) -> bool: + return bool(series.str.match(_NO_GAP_PATTERN).fillna(False).all()) + + +# Base schema for all antibody datasets (production: strict, no NaN labels) +def get_sequence_dataset_schema() -> pa.DataFrameSchema: + return pa.DataFrameSchema( + columns={ + "sequence": pa.Column( + dtype="string", + checks=[ + _regex_check(_UPPERCASE_PATTERN, name="uppercase_letters"), + _length_check(1, 2000, name="length_1_2000"), + pa.Check(_amino_acid_check, name="valid_amino_acids"), + pa.Check(_no_gap_check, name="no_gap_characters"), + ], + nullable=False, + coerce=True, # Auto-convert to string + description="Antibody amino acid sequence (VH, VL, or VHH)", + ), + "label": pa.Column( + dtype="int64", + checks=[ + pa.Check( + lambda series: series.isin([0, 1]).all(), + name="binary_label", + ), + ], + nullable=False, + description="Binary label: 0=specific, 1=non-specific", + ), + }, + strict=False, # Allow extra columns (e.g., id, metadata) + coerce=True, # Auto-coerce types when possible + name="SequenceDataset", + ) + + +# Preprocessing schema (allows nullable labels for held-out/intermediate data) +def get_preprocessing_schema() -> pa.DataFrameSchema: + """ + Schema for preprocessing intermediate files (e.g., Boughter annotated/). + + Allows nullable labels for sequences held out due to quality flags. + For production training/testing, use get_sequence_dataset_schema() instead. 
+ """ + return pa.DataFrameSchema( + columns={ + "sequence": pa.Column( + dtype="string", + checks=[ + _regex_check(_UPPERCASE_PATTERN, name="uppercase_letters"), + _length_check(1, 2000, name="length_1_2000"), + pa.Check(_amino_acid_check, name="valid_amino_acids"), + pa.Check(_no_gap_check, name="no_gap_characters"), + ], + nullable=False, + coerce=True, + description="Antibody amino acid sequence (VH, VL, or VHH)", + ), + "label": pa.Column( + dtype="float64", # float64 to handle NaN + checks=[ + # Only check non-null values are 0 or 1 + pa.Check( + lambda series: series.dropna().isin([0, 1, 0.0, 1.0]).all(), + name="binary_label_when_present", + ), + ], + nullable=True, # Allow NaN for held-out sequences + coerce=True, + description="Binary label: 0=specific, 1=non-specific (nullable for held-out)", + ), + }, + strict=False, + coerce=True, + name="PreprocessingDataset", + ) + + +# Boughter-specific schema (extends base) +def get_boughter_schema() -> pa.DataFrameSchema: + return get_sequence_dataset_schema().add_columns( + { + "id": pa.Column( + dtype="string", + nullable=True, + required=False, + description="Antibody identifier", + ), + # Boughter has additional metadata columns + "vh_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + description="Heavy chain variable domain (if paired)", + ), + } + ) + + +# Jain-specific schema +def get_jain_schema() -> pa.DataFrameSchema: + return get_sequence_dataset_schema().add_columns( + { + "id": pa.Column( + dtype="string", + nullable=False, + checks=[ + pa.Check.str_length(min_value=1), + ], + description="Antibody INN name (required for Jain)", + ), + "vh_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + description="VH sequence (Jain has full paired data)", + ), + "vl_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + description="VL sequence", + ), + } + ) + + +# Jain preprocessing schema (allows nullable labels for full stage) +def 
get_jain_preprocessing_schema() -> pa.DataFrameSchema: + return get_preprocessing_schema().add_columns( + { + "id": pa.Column( + dtype="string", + nullable=False, + checks=[ + pa.Check.str_length(min_value=1), + ], + description="Antibody INN name (required for Jain)", + ), + "vh_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + description="VH sequence (Jain has full paired data)", + ), + "vl_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + description="VL sequence", + ), + } + ) + + +# Harvey-specific schema (VHH only, no light chain) +def get_harvey_schema() -> pa.DataFrameSchema: + return pa.DataFrameSchema( + columns={ + "sequence": pa.Column( + dtype="string", + checks=[ + _regex_check(_UPPERCASE_PATTERN, name="uppercase_letters"), + _length_check(1, 2000, name="length_1_2000"), + pa.Check(_amino_acid_check, name="valid_amino_acids"), + pa.Check(_no_gap_check, name="no_gap_characters"), + ], + nullable=False, + description="Nanobody VHH sequence", + ), + "label": pa.Column( + dtype="int64", + checks=[pa.Check.isin([0, 1])], + nullable=False, + ), + # Harvey has pre-annotated CDRs + "cdr1": pa.Column(dtype="string", nullable=True, required=False), + "cdr2": pa.Column(dtype="string", nullable=True, required=False), + "cdr3": pa.Column(dtype="string", nullable=True, required=False), + }, + strict=False, + coerce=True, + name="HarveyNanobodyDataset", + ) + + +# Shehata-specific schema (paired antibodies with PSR measurements) +def get_shehata_schema() -> pa.DataFrameSchema: + return get_sequence_dataset_schema().add_columns( + { + "psr_measurement": pa.Column( + dtype="float64", + checks=[ + pa.Check.in_range(min_value=0.0, max_value=1.0), + ], + nullable=True, + required=False, + description="PSR assay measurement (0-1 range)", + ), + "vh_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + ), + "vl_sequence": pa.Column( + dtype="string", + nullable=True, + required=False, + ), + } + 
)
diff --git a/src/antibody_training_esm/settings.py b/src/antibody_training_esm/settings.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c87561eab3e54a8e786b809712f27e8b49cf02a
--- /dev/null
+++ b/src/antibody_training_esm/settings.py
from pathlib import Path

from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class DataSettings(BaseSettings):
    """
    Centralized data path configuration.

    Replaces hardcoded paths in preprocessing/paths.py.
    Allows override via env vars (e.g. ANTIBODY_DATA_DIR=/tmp/data).
    """

    # Env vars are read with the ANTIBODY_ prefix; a local .env file is
    # honored and unknown keys in it are ignored.
    model_config = SettingsConfigDict(
        env_prefix="ANTIBODY_", env_file=".env", extra="ignore"
    )

    # parents[2] of src/antibody_training_esm/settings.py is the repository
    # root (assumes the src/ layout shown in the file path above).
    PROJECT_ROOT: Path = Field(
        default_factory=lambda: Path(__file__).resolve().parents[2]
    )
    # Relative defaults; resolved against PROJECT_ROOT by _normalize_base_dirs.
    DATA_DIR: Path = Field(default_factory=lambda: Path("data"))
    EXPERIMENTS_DIR: Path = Field(default_factory=lambda: Path("experiments"))

    def _resolve(self, path: Path) -> Path:
        """Resolve relative paths against PROJECT_ROOT."""
        return path if path.is_absolute() else self.PROJECT_ROOT / path

    @model_validator(mode="after")
    def _normalize_base_dirs(self) -> "DataSettings":
        # Make the two base dirs absolute once, at construction time.
        # NOTE(review): object.__setattr__ bypasses pydantic's assignment
        # machinery -- presumably to avoid re-triggering validation; confirm
        # if validate_assignment is ever enabled on this model.
        object.__setattr__(self, "DATA_DIR", self._resolve(self.DATA_DIR))
        object.__setattr__(self, "EXPERIMENTS_DIR", self._resolve(self.EXPERIMENTS_DIR))
        return self

    # ============================================================================
    # Base data directories
    # ============================================================================
    @property
    def DATA_TRAIN_DIR(self) -> Path:
        # DATA_DIR is already absolute after _normalize_base_dirs; _resolve is
        # a no-op safeguard here.
        return self._resolve(self.DATA_DIR) / "train"

    @property
    def DATA_TEST_DIR(self) -> Path:
        return self._resolve(self.DATA_DIR) / "test"

    # ============================================================================
    # Boughter (training set)
    # ============================================================================
    @property
    def BOUGHTER_DIR(self) -> Path:
        return self.DATA_TRAIN_DIR / "boughter"

    @property
    def BOUGHTER_RAW_DIR(self) -> Path:
        return self.BOUGHTER_DIR / "raw"

    @property
    def BOUGHTER_PROCESSED_DIR(self) -> Path:
        return self.BOUGHTER_DIR / "processed"

    @property
    def BOUGHTER_ANNOTATED_DIR(self) -> Path:
        return self.BOUGHTER_DIR / "annotated"

    @property
    def BOUGHTER_CANONICAL_DIR(self) -> Path:
        return self.BOUGHTER_DIR / "canonical"

    @property
    def BOUGHTER_PROCESSED_CSV(self) -> Path:
        return self.BOUGHTER_PROCESSED_DIR / "boughter.csv"

    @property
    def BOUGHTER_CANONICAL_CSV(self) -> Path:
        return self.BOUGHTER_CANONICAL_DIR / "boughter_vh_914.csv"

    # ============================================================================
    # Jain (test set)
    # ============================================================================
    @property
    def JAIN_DIR(self) -> Path:
        return self.DATA_TEST_DIR / "jain"

    @property
    def JAIN_RAW_DIR(self) -> Path:
        return self.JAIN_DIR / "raw"

    @property
    def JAIN_PROCESSED_DIR(self) -> Path:
        return self.JAIN_DIR / "processed"

    @property
    def JAIN_FRAGMENTS_DIR(self) -> Path:
        return self.JAIN_DIR / "fragments"

    @property
    def JAIN_CANONICAL_DIR(self) -> Path:
        return self.JAIN_DIR / "canonical"

    @property
    def JAIN_FULL_CSV(self) -> Path:
        return self.JAIN_PROCESSED_DIR / "jain_with_private_elisa_FULL.csv"

    @property
    def JAIN_SD03_CSV(self) -> Path:
        return self.JAIN_PROCESSED_DIR / "jain_sd03.csv"

    @property
    def JAIN_OUTPUT_DIR(self) -> Path:
        # Alias: Jain preprocessing output goes to the fragments directory.
        return self.JAIN_FRAGMENTS_DIR

    # ============================================================================
    # Harvey (test set)
    # ============================================================================
    @property
    def HARVEY_DIR(self) -> Path:
        return self.DATA_TEST_DIR / "harvey"

    @property
    def HARVEY_RAW_DIR(self) -> Path:
        return self.HARVEY_DIR / "raw"

    @property
    def HARVEY_PROCESSED_DIR(self) -> Path:
        return self.HARVEY_DIR / "processed"

    @property
    def HARVEY_FRAGMENTS_DIR(self) -> Path:
        return self.HARVEY_DIR / "fragments"

    @property
    def HARVEY_HIGH_POLY_CSV(self) -> Path:
        return self.HARVEY_RAW_DIR / "high_polyreactivity_high_throughput.csv"

    @property
    def HARVEY_LOW_POLY_CSV(self) -> Path:
        return self.HARVEY_RAW_DIR / "low_polyreactivity_high_throughput.csv"

    @property
    def HARVEY_OUTPUT_DIR(self) -> Path:
        # Alias: Harvey preprocessing output goes to the fragments directory.
        return self.HARVEY_FRAGMENTS_DIR

    # ============================================================================
    # Shehata (test set)
    # ============================================================================
    @property
    def SHEHATA_DIR(self) -> Path:
        return self.DATA_TEST_DIR / "shehata"

    @property
    def SHEHATA_RAW_DIR(self) -> Path:
        return self.SHEHATA_DIR / "raw"

    @property
    def SHEHATA_PROCESSED_DIR(self) -> Path:
        return self.SHEHATA_DIR / "processed"

    @property
    def SHEHATA_FRAGMENTS_DIR(self) -> Path:
        return self.SHEHATA_DIR / "fragments"

    @property
    def SHEHATA_CANONICAL_DIR(self) -> Path:
        return self.SHEHATA_DIR / "canonical"

    @property
    def SHEHATA_EXCEL_PATH(self) -> Path:
        return self.SHEHATA_RAW_DIR / "shehata-mmc2.xlsx"

    @property
    def SHEHATA_OUTPUT_DIR(self) -> Path:
        # Alias: Shehata preprocessing output goes to the fragments directory.
        return self.SHEHATA_FRAGMENTS_DIR


# Global instance
settings = DataSettings()