VibecoderMcSwaggins committed
Commit 6491864
1 parent: 1a47b7e

Initial deployment: Antibody non-specificity predictor


- ESM-1v (650M) + Logistic Regression
- Trained on Boughter dataset
- Pydantic v2 validation
- Gradio 5.x UI
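The "Pydantic v2 validation" bullet above refers to a request model that cleans and validates input sequences before inference. A minimal sketch of such a model, assuming the constraints the README states (standard amino acids only, 2000-residue cap, lowercase auto-uppercased) — the class and field names here are illustrative, not the pipeline's actual `PredictionRequest`:

```python
from pydantic import BaseModel, ValidationError, field_validator

VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")  # the 20 standard amino acids


class SequenceRequest(BaseModel):
    """Hypothetical stand-in for the app's PredictionRequest model."""

    sequence: str

    @field_validator("sequence")
    @classmethod
    def clean_sequence(cls, v: str) -> str:
        v = v.strip().upper()  # auto-clean: lowercase -> uppercase
        if not v:
            raise ValueError("sequence must not be empty")
        if len(v) > 2000:
            raise ValueError("sequence exceeds 2000 amino acids")
        invalid = set(v) - VALID_AA
        if invalid:
            raise ValueError(f"invalid characters: {sorted(invalid)}")
        return v


print(SequenceRequest(sequence="qvqlvqsg").sequence)  # -> QVQLVQSG
```

With a model like this, the Gradio handler only needs to catch `ValidationError` and surface the first error message to the user, which is exactly what the deployed `app.py` below does.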

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. README.md +62 -7
  2. app.py +152 -0
  3. experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl +3 -0
  4. pyproject.toml +215 -0
  5. requirements.txt +28 -0
  6. src/antibody_training_esm/__init__.py +0 -0
  7. src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc +0 -0
  8. src/antibody_training_esm/__pycache__/settings.cpython-312.pyc +0 -0
  9. src/antibody_training_esm/cli/__init__.py +10 -0
  10. src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc +0 -0
  11. src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc +0 -0
  12. src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc +0 -0
  13. src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc +0 -0
  14. src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc +0 -0
  15. src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc +0 -0
  16. src/antibody_training_esm/cli/app.py +197 -0
  17. src/antibody_training_esm/cli/predict.py +116 -0
  18. src/antibody_training_esm/cli/preprocess.py +84 -0
  19. src/antibody_training_esm/cli/test.py +155 -0
  20. src/antibody_training_esm/cli/testing/__init__.py +1 -0
  21. src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc +0 -0
  22. src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc +0 -0
  23. src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc +0 -0
  24. src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc +0 -0
  25. src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc +0 -0
  26. src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc +0 -0
  27. src/antibody_training_esm/cli/testing/config.py +62 -0
  28. src/antibody_training_esm/cli/testing/data.py +73 -0
  29. src/antibody_training_esm/cli/testing/evaluation.py +134 -0
  30. src/antibody_training_esm/cli/testing/tester.py +384 -0
  31. src/antibody_training_esm/cli/testing/visualization.py +127 -0
  32. src/antibody_training_esm/cli/train.py +42 -0
  33. src/antibody_training_esm/conf/__init__.py +9 -0
  34. src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc +0 -0
  35. src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc +0 -0
  36. src/antibody_training_esm/conf/classifier/logreg.yaml +12 -0
  37. src/antibody_training_esm/conf/classifier/xgboost.yaml +14 -0
  38. src/antibody_training_esm/conf/config.yaml +36 -0
  39. src/antibody_training_esm/conf/config_schema.py +142 -0
  40. src/antibody_training_esm/conf/data/boughter_jain.yaml +23 -0
  41. src/antibody_training_esm/conf/hardware/default.yaml +5 -0
  42. src/antibody_training_esm/conf/hydra/default.yaml +10 -0
  43. src/antibody_training_esm/conf/model/esm1v.yaml +4 -0
  44. src/antibody_training_esm/conf/model/esm2_650m.yaml +3 -0
  45. src/antibody_training_esm/conf/predict.yaml +26 -0
  46. src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml +7 -0
  47. src/antibody_training_esm/core/__init__.py +19 -0
  48. src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc +0 -0
  49. src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc +0 -0
  50. src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,67 @@
  ---
- title: Antibody Predictor
- emoji: 🐨
- colorFrom: pink
- colorTo: indigo
  sdk: gradio
- sdk_version: 6.0.0
- app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: Antibody Non-Specificity Predictor
+ emoji: 🧬
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
+ sdk_version: "5.0.0"
+ app_file: spaces/app.py
  pinned: false
+ license: mit
+ tags:
+ - antibody
+ - protein
+ - ESM
+ - gradio
+ - polyreactivity
+ - machine-learning
  ---

+ # 🧬 Antibody Non-Specificity Predictor
+
+ Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) or Variable Light (VL) sequences using ESM-1v protein language models.
+
+ ## Model
+
+ - **Architecture:** ESM-1v (650M parameters) + Logistic Regression
+ - **Training Data:** Boughter dataset (914 antibodies, ELISA polyreactivity)
+ - **Methodology:** Sakhnini et al. (2025) - Prediction of Antibody Non-Specificity using PLMs
+
+ ## Usage
+
+ 1. Paste your antibody VH or VL amino acid sequence
+ 2. Click "🔬 Predict Non-Specificity"
+ 3. Get prediction (specific vs non-specific) + probability
+
+ ## Supported Input
+
+ - **Valid characters:** Standard amino acids (ACDEFGHIKLMNPQRSTVWY)
+ - **Max length:** 2000 amino acids
+ - **Auto-cleaning:** Lowercase automatically converted to uppercase
+
+ ## Examples
+
+ The app includes example sequences:
+ - Standard VH (128aa)
+ - Standard VL (107aa)
+ - Short VH (Herceptin-like)
+
+ ## Citation
+
+ If you use this tool in your research, please cite:
+
+ ```bibtex
+ @article{sakhnini2025antibody,
+   title={Prediction of Antibody Non-Specificity using Protein Language Models},
+   author={Sakhnini, et al.},
+   year={2025}
+ }
+ ```
+
+ ## Repository
+
+ Full source code: [antibody_training_pipeline_ESM](https://github.com/The-Obstacle-Is-The-Way/antibody_training_pipeline_ESM)
+
+ ## License
+
+ MIT License - See repository for details
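The classifier head named in the Model section is a plain scikit-learn logistic regression fitted on ESM-1v embeddings. That training step can be sketched with random stand-in vectors (1280 dimensions, matching ESM-1v's hidden size; real features would come from the 650M model, which is out of scope here):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(42)
X = rng.normal(size=(914, 1280))   # stand-in for mean-pooled ESM-1v embeddings
y = rng.integers(0, 2, size=914)   # stand-in ELISA polyreactivity labels

clf = LogisticRegression(max_iter=1000)
clf.fit(X, y)

# predict_proba[:, 1] is the probability of the positive ("non-specific") class
proba = clf.predict_proba(X[:1])[0, 1]
print(f"{proba:.1%}")
```

The fitted classifier is tiny (one weight per embedding dimension), which is why the checkpoint committed below is only ~11 kB.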
app.py ADDED
@@ -0,0 +1,152 @@
+ """
+ Hugging Face Spaces Gradio App for Antibody Non-Specificity Prediction
+
+ Simplified deployment version (no Hydra, no complex dependencies).
+ Works on HF Spaces free CPU tier.
+
+ Local app (src/antibody_training_esm/cli/app.py) remains unchanged.
+ """
+
+ import logging
+ import os
+
+ import gradio as gr
+ import torch
+ from pydantic import ValidationError
+
+ from antibody_training_esm.core.prediction import Predictor
+ from antibody_training_esm.models.prediction import PredictionRequest
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # HF Spaces environment detection
+ IS_HF_SPACE = os.getenv("SPACE_ID") is not None
+
+ # Model path (either local or downloaded from HF Hub)
+ MODEL_PATH = os.getenv(
+     "MODEL_PATH", "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"
+ )
+
+ # ESM model name
+ MODEL_NAME = "facebook/esm1v_t33_650M_UR90S_1"
+
+ # Force CPU for HF Spaces free tier
+ DEVICE = "cpu"
+
+ # Load model globally (HF Spaces best practice)
+ logger.info(f"Loading model from {MODEL_PATH}...")
+ predictor = Predictor(
+     model_name=MODEL_NAME, classifier_path=MODEL_PATH, device=DEVICE, config_path=None
+ )
+
+ # Warm up model
+ try:
+     logger.info("Warming up model...")
+     predictor.predict_single("QVQL")
+     logger.info("Model ready!")
+ except Exception as e:
+     logger.warning(f"Warmup failed (non-fatal): {e}")
+
+
+ def predict_sequence(sequence: str) -> tuple[str, str]:
+     """
+     Prediction function for Gradio interface.
+
+     Args:
+         sequence: Antibody amino acid sequence
+
+     Returns:
+         Tuple of (prediction, probability)
+     """
+     try:
+         # Validate with Pydantic
+         request = PredictionRequest(sequence=sequence)
+
+         # Log request
+         logger.info(f"Processing sequence: length={len(request.sequence)}")
+
+         # Predict
+         result = predictor.predict_single(request)
+
+         # Format probability
+         prob_percent = f"{result.probability:.1%}"
+
+         return result.prediction, prob_percent
+
+     except ValidationError as e:
+         # User-friendly error message
+         error_msg = e.errors()[0]["msg"]
+         raise gr.Error(error_msg) from e
+     except torch.cuda.OutOfMemoryError as e:
+         logger.error("GPU OOM during inference")
+         raise gr.Error(
+             "Server overloaded (GPU OOM). Please try again in a moment."
+         ) from e
+     except Exception as e:
+         logger.exception("Unexpected prediction failure")
+         raise gr.Error(f"Prediction failed: {str(e)}") from e
+
+
+ # Example sequences
+ examples = [
+     [
+         "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS"
+     ],  # Standard VH
+     [
+         "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK"
+     ],  # Standard VL
+     [
+         "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS"
+     ],  # Short VH
+ ]
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=predict_sequence,
+     inputs=gr.TextArea(
+         lines=7,
+         max_lines=20,
+         max_length=2000,
+         label="Antibody Sequence (VH or VL)",
+         placeholder="Paste amino acid sequence here (e.g., QVQL...)",
+         info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).",
+         show_copy_button=True,
+     ),
+     outputs=[
+         gr.Textbox(label="Prediction", show_copy_button=True),
+         gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True),
+     ],
+     title="🧬 Antibody Non-Specificity Predictor",
+     description=(
+         "Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) "
+         "or Variable Light (VL) sequences using ESM-1v protein language models.\n\n"
+         "**Model:** ESM-1v (650M parameters) + Logistic Regression\n"
+         "**Training:** Boughter dataset (914 antibodies, ELISA polyreactivity)\n"
+         "**Citation:** Sakhnini et al. (2025) - Prediction of Antibody Non-Specificity using PLMs"
+     ),
+     article=(
+         f"**Model:** {MODEL_NAME}\n"
+         f"**Device:** {DEVICE}\n"
+         f"**Environment:** {'Hugging Face Spaces' if IS_HF_SPACE else 'Local'}"
+     ),
+     examples=examples,
+     cache_examples=False,  # Don't cache on HF Spaces (saves disk)
+     flagging_mode="never",
+     analytics_enabled=False,
+     submit_btn="🔬 Predict Non-Specificity",
+     clear_btn="🗑️ Clear",
+ )
+
+ # Enable queue for concurrency
+ iface.queue(default_concurrency_limit=2, max_size=10)
+
+ # Launch app
+ if __name__ == "__main__":
+     iface.launch(
+         server_name="0.0.0.0",  # Required for HF Spaces
+         server_port=7860,
+         share=False,
+         show_api=False,  # No public REST API
+     )
experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4f77cadfd0ccf3a12c24ce142a91c82b4481d5153a0af662ac4b05a78ef6670
+ size 11314
pyproject.toml ADDED
@@ -0,0 +1,215 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/antibody_training_esm"]
+ include = [
+     "src/antibody_training_esm/conf/**/*.yaml",
+     "src/antibody_training_esm/conf/**/*.py",
+ ]
+
+ [tool.hatch.build.targets.sdist]
+ # Source distribution must include all source files + configs
+ include = [
+     "src/antibody_training_esm/**/*.py",
+     "src/antibody_training_esm/conf/**/*.yaml",
+     "tests/**/*.py",
+     "README.md",
+     "pyproject.toml",
+     "LICENSE",
+ ]
+
+ [project]
+ name = "antibody-training-esm"
+ version = "0.7.0"
+ description = "Professional antibody training pipeline using ESM protein language models"
+ license = {text = "Apache-2.0"}
+ requires-python = ">=3.12"
+ dependencies = [
+     "authlib>=1.6.5",
+     "biopython>=1.80",
+     "brotli>=1.2.0",
+     "datasets>=4.2.0",
+     "h2>=4.3.0",
+     "hydra-core>=1.3.2",
+     "jupyterlab>=4.4.9",
+     "matplotlib>=3.7.0",
+     "more-itertools",
+     "numpy>=1.24.0",
+     "pandas>=2.0.0",
+     "plotly",
+     "pyparsing>=3.0.0",
+     "PyYAML>=6.0.0",
+     "riot_na",
+     "scikit-learn>=1.3.0",
+     "scipy>=1.10.0",
+     "seaborn>=0.12.0",
+     "torch>=2.6.0",
+     "tqdm>=4.65.0",
+     "transformers>=4.30.0",
+     "xgboost>=2.0.0",
+     "gradio>=4.0.0",
+ ]
+
+ [project.optional-dependencies]
+ validation = [
+     "pydantic>=2.10.0",          # Stable v2 release
+     "pydantic-settings>=2.6.0",  # For future config management
+     "pandera>=0.20.0",           # Phase 3: Data Integrity
+ ]
+ dev = [
+     # Testing
+     "pytest>=8.3.0",
+     "pytest-cov>=6.0.0",
+     "pytest-xdist>=3.6.0",
+     "pytest-sugar>=1.0.0",
+
+     # Linting & Formatting
+     "ruff>=0.8.0",
+
+     # Type Checking
+     "mypy>=1.13.0",
+     "pandas-stubs>=2.2.0",
+
+     # Security
+     "bandit[toml]>=1.7.0",
+
+     # Pre-commit
+     "pre-commit>=4.0.0",
+
+     # Documentation
+     "mkdocs>=1.6.0",
+     "mkdocs-material>=9.5.0",
+     "mkdocstrings[python]>=0.26.0",
+     "mkdocs-gen-files>=0.5.0",
+     "mkdocs-literate-nav>=0.6.0",
+     "mkdocs-section-index>=0.3.0",
+     "pymdown-extensions>=10.0.0",
+ ]
+
+ [project.scripts]
+ # Point directly to Hydra-decorated function to enable config group overrides
+ # (antibody-train model=esm2_650m classifier=xgboost now works correctly)
+ antibody-train = "antibody_training_esm.core.trainer:main"
+ antibody-test = "antibody_training_esm.cli.test:main"
+ antibody-preprocess = "antibody_training_esm.cli.preprocess:main"
+ antibody-predict = "antibody_training_esm.cli.predict:main"
+ antibody-app = "antibody_training_esm.cli.app:main"
+
+ [tool.ruff]
+ target-version = "py312"
+ line-length = 88
+
+ [tool.ruff.lint]
+ select = [
+     "E",    # pycodestyle errors
+     "W",    # pycodestyle warnings
+     "F",    # pyflakes
+     "I",    # isort
+     "B",    # flake8-bugbear
+     "C4",   # flake8-comprehensions
+     "UP",   # pyupgrade
+     "ARG",  # flake8-unused-arguments
+     "SIM",  # flake8-simplify
+ ]
+ ignore = [
+     "E501",  # line too long (handled by formatter)
+ ]
+
+ [tool.ruff.lint.per-file-ignores]
+ "__init__.py" = ["F401"]
+ "tests/**/*" = ["ARG"]
+ "experiments/**/*" = ["ALL"]
+ "reference_repos/**/*" = ["ALL"]
+
+ [tool.ruff.format]
+ quote-style = "double"
+ indent-style = "space"
+
+ [tool.mypy]
+ python_version = "3.12"
+ warn_return_any = true
+ warn_unused_configs = true
+ disallow_untyped_defs = true
+ ignore_missing_imports = true
+ exclude = [
+     "experiments/",
+     "reference_repos/",
+     "site/",                         # MkDocs generated documentation
+     "tests/unit/cli/test_train.py",  # Legacy CLI tests (deprecated)
+ ]
+
+ [tool.pytest.ini_options]
+ # Pytest configuration (canonical source - pytest.ini deleted for single source of truth)
+ testpaths = ["tests"]
+ python_files = ["test_*.py"]
+ python_classes = ["Test*"]
+ python_functions = ["test_*"]
+ addopts = [
+     # Output formatting
+     "-v",
+     "--tb=short",
+     "--strict-markers",
+     "-ra",
+     # Coverage reporting
+     "--cov=src/antibody_training_esm",
+     "--cov-report=html",
+     "--cov-report=term-missing",
+     # Performance
+     "--maxfail=10",
+ ]
+ markers = [
+     "unit: Unit tests (fast, no I/O) - Core business logic",
+     "integration: Integration tests (medium speed, some I/O) - Component interactions",
+     "e2e: End-to-end tests (slow, full pipeline) - Full workflows",
+     "slow: Tests that take >1s to run",
+     "gpu: Tests that require GPU (skip in CI with: -m 'not gpu')",
+     "legacy: Legacy tests for backward compatibility (deprecated, will be removed)",
+ ]
+ filterwarnings = [
+     # sklearn deprecation warnings
+     "ignore:.*__sklearn_tags__.*:DeprecationWarning:sklearn.utils._tags",
+     # sklearn convergence warnings (expected with small test datasets)
+     "ignore:.*lbfgs failed to converge.*:sklearn.exceptions.ConvergenceWarning",
+     "ignore:.*lbfgs failed to converge.*:UserWarning:sklearn.linear_model._logistic",
+     # sklearn scoring warnings (expected when testing edge cases)
+     "ignore:.*Scoring failed.*:UserWarning:sklearn.model_selection._validation",
+     # sklearn undefined metric warnings (expected with edge case test data)
+     "ignore:.*Precision is ill-defined.*:sklearn.exceptions.UndefinedMetricWarning",
+     "ignore:.*Precision is ill-defined.*:UserWarning:sklearn.metrics._classification",
+     # pytest collection warnings (TestConfig is a dataclass, not a test class)
+     "ignore:.*cannot collect test class.*TestConfig.*:pytest.PytestCollectionWarning",
+     # General deprecation warnings
+     "ignore::DeprecationWarning",
+     "ignore::PendingDeprecationWarning",
+ ]
+
+ [tool.coverage.run]
+ source = ["src"]
+ omit = [
+     "tests/*",
+     "experiments/*",
+     "reference_repos/*",
+     "**/__pycache__/*",
+     ".venv/*",
+     "**/conftest.py",
+ ]
+ branch = true
+
+ [tool.coverage.report]
+ precision = 2
+ exclude_lines = [
+     "pragma: no cover",
+     "def __repr__",
+     "raise AssertionError",
+     "raise NotImplementedError",
+     "if __name__ == .__main__.:",
+     "if TYPE_CHECKING:",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "openpyxl>=3.1.5",
+     "types-pyyaml>=6.0.12.20250915",
+ ]
requirements.txt ADDED
@@ -0,0 +1,28 @@
+ # Hugging Face Spaces Requirements
+ # Minimal dependencies for antibody prediction demo
+
+ # Core ML
+ torch>=2.0.0
+ transformers>=4.30.0
+ scikit-learn>=1.3.0
+ scipy>=1.10.0
+ joblib>=1.3.0
+
+ # Data handling
+ pandas>=2.0.0
+ numpy>=1.24.0
+
+ # Configuration
+ omegaconf>=2.3.0
+
+ # Validation
+ pydantic>=2.0.0
+
+ # Gradio UI
+ gradio>=5.0.0
+
+ # Progress bars
+ tqdm>=4.65.0
+
+ # Install local package (antibody_training_esm)
+ .
src/antibody_training_esm/__init__.py ADDED
File without changes
src/antibody_training_esm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (205 Bytes)
src/antibody_training_esm/__pycache__/settings.cpython-312.pyc ADDED
Binary file (9.61 kB)
 
src/antibody_training_esm/cli/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ CLI Module
+
+ Professional command-line interfaces for antibody training pipeline:
+ - antibody-train: Model training
+ - antibody-test: Model evaluation
+ - antibody-preprocess: Dataset preprocessing
+ """
+
+ __all__ = []
src/antibody_training_esm/cli/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (440 Bytes)
src/antibody_training_esm/cli/__pycache__/app.cpython-312.pyc ADDED
Binary file (7.8 kB)
src/antibody_training_esm/cli/__pycache__/predict.cpython-312.pyc ADDED
Binary file (5.09 kB)
src/antibody_training_esm/cli/__pycache__/preprocess.cpython-312.pyc ADDED
Binary file (3.6 kB)
src/antibody_training_esm/cli/__pycache__/test.cpython-312.pyc ADDED
Binary file (6.49 kB)
src/antibody_training_esm/cli/__pycache__/train.cpython-312.pyc ADDED
Binary file (1.29 kB)
 
src/antibody_training_esm/cli/app.py ADDED
@@ -0,0 +1,197 @@
+ """
+ This module contains the Gradio app for the antibody non-specificity prediction pipeline.
+ """
+
+ import logging
+ import platform
+ from pathlib import Path
+
+ import gradio as gr
+ import hydra
+ import torch
+ from omegaconf import DictConfig
+ from pydantic import ValidationError
+
+ from antibody_training_esm.core.prediction import Predictor
+ from antibody_training_esm.models.prediction import PredictionRequest
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+
+ def launch_gradio_app(cfg: DictConfig) -> None:
+     """
+     Launches the Gradio web UI for antibody prediction.
+
+     This function sets up a Gradio interface that allows users to input an
+     antibody sequence and receive a prediction for its non-specificity.
+
+     Args:
+         cfg: The Hydra configuration object.
+     """
+     # Set log level from config
+     logging.basicConfig(
+         level=getattr(logging, cfg.gradio.log_level.upper(), logging.INFO)
+     )
+
+     # Robust Device & Threading Configuration
+     # -------------------------------------------------------------------------
+     # 1. Determine the optimal device for inference
+     #    - Prefer CUDA if available (Linux/Windows GPU boxes)
+     #    - Force CPU on macOS if MPS is detected to avoid Gradio+MPS SegFaults
+     #    - Default to configured value otherwise
+     device = cfg.model.get("device", "cpu")
+
+     if platform.system() == "Darwin" and device == "mps":
+         logger.warning(
+             "macOS detected. Forcing CPU for Gradio app stability (MPS workaround)."
+         )
+         device = "cpu"
+
+     # 2. Configure threading to prevent OpenMP SegFaults on macOS
+     #    - On macOS/CPU, PyTorch's OpenMP runtime can crash inside Gradio threads.
+     #    - We restrict it to 1 thread to ensure stability.
+     #    - Linux/CUDA systems remain untouched and can use full parallelism.
+     if platform.system() == "Darwin" and device == "cpu":
+         logger.warning(
+             "macOS/CPU detected. Setting torch.set_num_threads(1) to prevent OpenMP crashes."
+         )
+         torch.set_num_threads(1)
+
+     if cfg.classifier.path is None:
+         raise ValueError(
+             "Classifier path must be specified via command-line override:\n"
+             "  classifier.path=experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"
+         )
+     classifier_path = Path(cfg.classifier.path)
+     if not classifier_path.exists():
+         raise FileNotFoundError(
+             f"Classifier file not found at {classifier_path}. "
+             "Train a model (e.g., `make train`) or download a published checkpoint first."
+         )
+
+     # Instantiate the predictor
+     config_path = getattr(cfg.classifier, "config_path", None)
+     predictor = Predictor(
+         model_name=cfg.model.name,
+         classifier_path=cfg.classifier.path,
+         device=device,
+         config_path=config_path,
+     )
+
+     # Warm-up: Run a dummy prediction to load the model into memory eagerly
+     try:
+         logger.info("Warming up model with dummy prediction...")
+         predictor.predict_single("QVQL")
+         logger.info("Model warmed up and ready.")
+     except Exception as e:
+         logger.warning(f"Model warm-up failed (non-fatal): {e}")
+
+     def predict_sequence(sequence: str) -> tuple[str, str]:
+         """
+         Prediction function for the Gradio interface.
+
+         Args:
+             sequence: The antibody sequence to predict.
+
+         Returns:
+             A tuple containing the prediction string and the formatted probability.
+         """
+         try:
+             # Validate with Pydantic (replaces old validate_input)
+             request = PredictionRequest(sequence=sequence)
+
+             # Log request (observability)
+             logger.info(f"Processing: length={len(request.sequence)}")
+
+             # Predict (returns PydanticResult)
+             result = predictor.predict_single(request)
+
+             # Format probability
+             prob_percent = f"{result.probability:.1%}"
+
+             return result.prediction, prob_percent
+
+         except ValidationError as e:
+             # Extract first error message for user-friendly display
+             error_msg = e.errors()[0]["msg"]
+             raise gr.Error(error_msg) from e
+         except torch.cuda.OutOfMemoryError as e:
+             logger.error("GPU OOM during inference")
+             raise gr.Error(
+                 "Server overloaded (GPU OOM). Please try again in a moment."
+             ) from e
+         except Exception as e:
+             logger.exception("Unexpected prediction failure")
+             raise gr.Error(f"Prediction failed: {str(e)}") from e
+
+     # Example sequences (Diverse set)
+     examples = [
+         [
+             "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS"
+         ],  # Standard VH
+         [
+             "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK"
+         ],  # Standard VL
+         [
+             "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS"
+         ],  # Short VH (Herceptin-like)
+     ]
+
+     # Create the Gradio interface
+     iface = gr.Interface(
+         fn=predict_sequence,
+         inputs=gr.TextArea(
+             lines=7,
+             max_lines=20,
+             max_length=2000,
+             label="Antibody Sequence (VH or VL)",
+             placeholder="Paste amino acid sequence here (e.g., QVQL...)",
+             info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).",
+             show_copy_button=True,
+         ),
+         outputs=[
+             gr.Textbox(label="Prediction", show_copy_button=True),
+             gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True),
+         ],
+         title="Antibody Non-Specificity Predictor",
+         description=(
+             "Enter an antibody Variable Heavy (VH) or Variable Light (VL) sequence "
+             "to predict its non-specificity (polyreactivity)."
+         ),
+         article=f"Model: {cfg.model.name} | Device: {device}",
+         examples=examples,
+         cache_examples=True,
+         flagging_mode="never",
+         analytics_enabled=False,
+         submit_btn="Predict Non-Specificity",
+     )
+
+     # Enable queueing for concurrency management
+     # Queue configuration:
+     #   - concurrency_limit: Based on available VRAM (approx 3GB per ESM-1v inference).
+     #   - max_size: Prevents unbounded queue growth under load.
+     iface.queue(
+         default_concurrency_limit=cfg.gradio.queue.concurrency_limit,
+         max_size=cfg.gradio.queue.max_size,
+     )
+
+     # Launch the app with hardened settings
+     iface.launch(
+         server_name=cfg.gradio.server_name,
+         server_port=cfg.gradio.server_port,
+         share=cfg.gradio.share,
+         show_api=False,
+     )
+
+
+ @hydra.main(config_path="../conf", config_name="predict", version_base=None)
+ def main(cfg: DictConfig) -> None:
+     """Main function to run the Gradio app."""
+     launch_gradio_app(cfg)
+
+
+ if __name__ == "__main__":
+     main()
src/antibody_training_esm/cli/predict.py ADDED
@@ -0,0 +1,116 @@
+ import sys
+ from pathlib import Path
+ from typing import cast
+
+ import hydra
+ import pandas as pd
+ from omegaconf import DictConfig
+ from pydantic import ValidationError
+
+ from antibody_training_esm.core.config import SEQUENCE_PREVIEW_LENGTH
+ from antibody_training_esm.core.prediction import Predictor, run_prediction
+ from antibody_training_esm.models.prediction import AssayType, PredictionRequest
+
+
+ def predict_sequence_cli(
+     sequence: str, threshold: float, assay_type: AssayType | None, cfg: DictConfig
+ ) -> None:
+     """CLI prediction with Pydantic validation."""
+     config_path = getattr(cfg.classifier, "config_path", None)
+
+     # Instantiate predictor (loading model)
+     try:
+         predictor = Predictor(
+             model_name=cfg.model.name,
+             classifier_path=cfg.classifier.path,
+             config_path=config_path,
+         )
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         sys.exit(1)
+
+     try:
+         request = PredictionRequest(
+             sequence=sequence,
+             threshold=threshold,
+             assay_type=assay_type,
+         )
+         result = predictor.predict_single(request)
+
+         # Print formatted output
+         print(
+             f"Sequence: {result.sequence[:SEQUENCE_PREVIEW_LENGTH]}..."
+             if len(result.sequence) > SEQUENCE_PREVIEW_LENGTH
+             else f"Sequence: {result.sequence}"
+         )
+         print(f"Prediction: {result.prediction}")
+         print(f"Probability: {result.probability:.2%}")
+
+     except ValidationError as e:
+         print("❌ Validation Error:")
+         for error in e.errors():
+             # loc is a tuple, e.g. ('sequence',)
+             loc = error["loc"][0] if error["loc"] else "root"
+             print(f"  - {loc}: {error['msg']}")
+         sys.exit(1)
+
+
+ @hydra.main(config_path="../conf", config_name="predict", version_base=None)
+ def main(cfg: DictConfig) -> None:
+     """Main function to run the prediction CLI."""
+
+     # Check for single sequence prediction mode
+     sequence = getattr(cfg, "sequence", None)
+     if sequence:
+         threshold = getattr(cfg, "threshold", 0.5)
+         assay_type = cast(AssayType | None, getattr(cfg, "assay_type", None))
+         predict_sequence_cli(sequence, threshold, assay_type, cfg)
+         return
+
+     # Validate required arguments for batch mode
+     if cfg.input_file is None:
+         raise ValueError(
+             "Input file must be specified via command-line override: `input_file=...`"
+         )
+
+     if cfg.classifier.path is None:
+         raise ValueError(
+             "Classifier path must be specified via command-line override:\n"
+             "  classifier.path=experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl\n"
+             "  # OR for production models (.npz):\n"
+             "  classifier.path=experiments/.../model.npz classifier.config_path=.../model_config.json\n"
+             "\nExample usage:\n"
+             "  uv run antibody-predict \\\n"
+             "    input_file=data/test.csv \\\n"
+             "    output_file=predictions.csv \\\n"
+             "    classifier.path=path/to/model.pkl"
+         )
+     classifier_path = Path(cfg.classifier.path)
+     if not classifier_path.exists():
+         raise FileNotFoundError(
+             f"Classifier file not found at {classifier_path}. "
+             "Train a model (e.g., `make train`) or download a published checkpoint first."
+         )
+
+     try:
+         # Load input data
+         input_df = pd.read_csv(cfg.input_file)
+
+         # Run prediction
+         output_df = run_prediction(input_df, cfg)
+
+         # Save output data
+         output_df.to_csv(cfg.output_file, index=False)
+
+         print(f"Predictions saved to {cfg.output_file}")
+
+     except FileNotFoundError:
+         print(f"Error: Input file not found at {cfg.input_file}")
+         exit(1)
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         exit(1)
+
+
+ if __name__ == "__main__":
+     main()
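The batch mode above is a read-CSV → score → write-CSV loop around `run_prediction`. That data flow can be sketched with a toy scorer standing in for the real embedding + classifier pipeline (the column names and the 0.5 threshold here are illustrative assumptions, not the pipeline's actual schema):

```python
import io

import pandas as pd

# In-memory stand-ins for cfg.input_file / cfg.output_file
csv_text = "sequence\nQVQLVQSG\nDIQMTQSP\n"
input_df = pd.read_csv(io.StringIO(csv_text))


def dummy_score(seq: str) -> float:
    # Toy stand-in for ESM-1v embedding + logistic regression
    return len(seq) / 100.0


output_df = input_df.copy()
output_df["probability"] = output_df["sequence"].map(dummy_score)
output_df["prediction"] = output_df["probability"].ge(0.5).map(
    {True: "non-specific", False: "specific"}
)

buf = io.StringIO()
output_df.to_csv(buf, index=False)  # equivalent of output_df.to_csv(cfg.output_file)
print(output_df)
```

The real CLI differs only in where the frame comes from (Hydra-configured paths) and in how the probability column is produced.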
src/antibody_training_esm/cli/preprocess.py ADDED
@@ -0,0 +1,84 @@
+ """
+ Preprocessing CLI
+
+ Professional command-line interface for dataset preprocessing.
+ """
+
+ import argparse
+ import sys
+
+
+ def main() -> int:
+     """
+     Main entry point for preprocessing CLI.
+
+     This CLI does NOT run preprocessing - it only provides guidance on which
+     preprocessing scripts to use. Preprocessing is handled by specialized
+     scripts that are the Single Source of Truth (SSOT).
+     """
+     parser = argparse.ArgumentParser(
+         description="Antibody dataset preprocessing guidance",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ NOTE: This CLI does NOT run preprocessing. It provides guidance on which
+ preprocessing scripts to use. Each dataset has unique requirements and the
+ scripts maintain bit-for-bit parity with published methods.
+ """,
+     )
+
+     parser.add_argument(
+         "--dataset",
+         "-d",
+         type=str,
+         required=True,
+         choices=["jain", "harvey", "shehata", "boughter"],
+         help="Dataset to get preprocessing guidance for",
+     )
+
+     args = parser.parse_args()
+
+     try:
+         print("\n⚠️ The 'antibody-preprocess' CLI is not implemented")
+         print(
+             "\nDataset preprocessing is handled by specialized scripts, not this CLI."
+         )
+         print(
+             "These scripts are the authoritative source of truth for data transformation."
+         )
+         print(f"\nFor {args.dataset} dataset, use:")
+
+         script_paths = {
+             "jain": "preprocessing/jain/step2_preprocess_p5e_s2.py",
+             "harvey": "preprocessing/harvey/step2_extract_fragments.py",
+             "shehata": "preprocessing/shehata/step2_extract_fragments.py",
+             "boughter": "preprocessing/boughter/stage2_stage3_annotation_qc.py",
+         }
+
+         script = script_paths.get(args.dataset)
+         if script:
+             print(f"  python {script}")
+
+         print("\nWhy use scripts instead of this CLI?")
+         print("  • Scripts are Single Source of Truth (SSOT) for preprocessing")
+         print(
+             "  • Each dataset has unique requirements (DNA translation, PSR thresholds, etc.)"
+         )
+         print("  • Scripts maintain bit-for-bit parity with published methods")
+         print("  • CLI is for loading preprocessed data, not creating it")
+
+         print("\nFor more information:")
+         print("  • See src/antibody_training_esm/datasets/README.md")
+         print("  • See docs/boughter/boughter_data_sources.md (dataset-specific)")
+
+         return 0
+
+     except KeyboardInterrupt:
+         print("\n❌ Error: Interrupted by user", file=sys.stderr)
+         return 1
+     except Exception as e:
+         print(f"\n❌ Error: {e}", file=sys.stderr)
+         return 1
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
src/antibody_training_esm/cli/test.py ADDED
@@ -0,0 +1,155 @@
+ #!/usr/bin/env python3
+ """
+ Test CLI for Antibody Classification Pipeline
+
+ Professional command-line interface for testing trained antibody classifiers:
+ 1. Load trained models from pickle files
+ 2. Evaluate on test datasets with performance metrics
+ 3. Generate confusion matrices and comprehensive logging
+
+ Usage:
+     antibody-test --model experiments/checkpoints/antibody_classifier.pkl --data sample_data.csv
+     antibody-test --config test_config.yaml
+     antibody-test --model m1.pkl m2.pkl --data d1.csv d2.csv
+ """
+
+ import argparse
+ import sys
+
+ from antibody_training_esm.cli.testing.config import (
+     TestConfig,
+     create_sample_test_config,
+     load_config_file,
+ )
+ from antibody_training_esm.cli.testing.tester import ModelTester
+
+
+ def main() -> int:
+     """Main entry point for antibody-test CLI"""
+     parser = argparse.ArgumentParser(
+         description="Testing for antibody classification models",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   # Test single model on single dataset (auto-detects threshold from dataset name)
+   antibody-test --model experiments/checkpoints/antibody_classifier.pkl --data sample_data.csv
+
+   # Test on PSR dataset with auto-detected threshold (0.5495 for Harvey/Shehata)
+   antibody-test --model model.pkl --data data/test/harvey/fragments/VHH_only_harvey.csv
+
+   # Test multiple models on multiple datasets
+   antibody-test --model experiments/checkpoints/model1.pkl experiments/checkpoints/model2.pkl --data dataset1.csv dataset2.csv
+
+   # Use configuration file
+   antibody-test --config test_config.yaml
+
+   # Override device, batch size, and threshold
+   antibody-test --config test_config.yaml --device cuda --batch-size 64 --threshold 0.6
+
+   # Create sample configuration
+   antibody-test --create-config
+ """,
+     )
+
+     parser.add_argument(
+         "--model", nargs="+", help="Path(s) to trained model pickle files"
+     )
+     parser.add_argument("--data", nargs="+", help="Path(s) to test dataset CSV files")
+     parser.add_argument("--config", help="Path to test configuration YAML file")
+     parser.add_argument(
+         "--output-dir",
+         default="./experiments/benchmarks",
+         help="Output directory for results",
+     )
+     parser.add_argument(
+         "--device",
+         choices=["cpu", "cuda", "mps"],
+         help="Device to use for inference (overrides config)",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         help="Batch size for embedding extraction (overrides config)",
+     )
+     parser.add_argument(
+         "--threshold",
+         type=float,
+         help="Manual decision threshold override (default: auto-detect from dataset name). "
+         "Use 0.5 for ELISA datasets (Boughter, Jain) or 0.5495 for PSR datasets (Harvey, Shehata).",
+     )
+     parser.add_argument(
+         "--sequence-column",
+         type=str,
+         help="Column name for sequences in dataset (default: 'sequence', overrides config)",
+     )
+     parser.add_argument(
+         "--label-column",
+         type=str,
+         help="Column name for labels in dataset (default: 'label', overrides config)",
+     )
+     parser.add_argument(
+         "--create-config", action="store_true", help="Create sample configuration file"
+     )
+
+     args = parser.parse_args()
+
+     # Create sample config if requested
+     if args.create_config:
+         create_sample_test_config()
+         return 0
+
+     # Load configuration
+     if args.config:
+         config = load_config_file(args.config)
+     else:
+         if not args.model or not args.data:
+             parser.error("Either --config or both --model and --data must be specified")
+
+         config = TestConfig(
+             model_paths=args.model, data_paths=args.data, output_dir=args.output_dir
+         )
+
+     # Override config with command-line arguments. Compare numeric flags against
+     # None so an explicit zero (e.g. --threshold 0.0) is not silently ignored.
+     if args.device:
+         config.device = args.device
+     if args.batch_size is not None:
+         config.batch_size = args.batch_size
+     if args.threshold is not None:
+         config.threshold = args.threshold
+     if args.sequence_column:
+         config.sequence_column = args.sequence_column
+     if args.label_column:
+         config.label_column = args.label_column
+
+     # Run testing
+     try:
+         tester = ModelTester(config)
+         results = tester.run_comprehensive_test()
+
+         print(f"\n{'=' * 60}")
+         print("TESTING COMPLETED SUCCESSFULLY!")
+         print(f"{'=' * 60}")
+         print(f"Results saved to: {config.output_dir}")
+
+         # Print summary
+         for dataset_name, dataset_results in results.items():
+             print(f"\nDataset: {dataset_name}")
+             print("-" * 40)
+             for model_name, model_results in dataset_results.items():
+                 print(f"Model: {model_name}")
+                 if "test_scores" in model_results:
+                     for metric, value in model_results["test_scores"].items():
+                         print(f"  {metric}: {value:.4f}")
+
+         return 0
+
+     except KeyboardInterrupt:
+         print("Error during testing: Interrupted by user", file=sys.stderr)
+         return 1
+     except Exception as e:
+         print(f"Error during testing: {e}", file=sys.stderr)
+         return 1
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
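The CLI-over-config precedence in `main()` can be exercised on its own; note the `is not None` comparison, which keeps an explicit `--threshold 0.0` from being discarded as falsy. The `apply_overrides` helper below is hypothetical, with plain dicts standing in for the argparse namespace and `TestConfig`:

```python
def apply_overrides(config: dict, args: dict) -> dict:
    """CLI values win over config-file values, but only when explicitly provided."""
    for key in ("device", "batch_size", "threshold", "sequence_column", "label_column"):
        # A flag the user did not pass arrives as None and must not clobber the config
        if args.get(key) is not None:
            config[key] = args[key]
    return config


cfg = {"device": "mps", "threshold": None}
print(apply_overrides(cfg, {"device": "cuda", "threshold": 0.0, "batch_size": None}))
# → {'device': 'cuda', 'threshold': 0.0}
```

A truthiness check (`if args.threshold:`) would drop the `0.0` here, which is exactly the edge case the explicit `None` comparison avoids.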
src/antibody_training_esm/cli/testing/__init__.py ADDED
@@ -0,0 +1 @@
+ """Test CLI package."""
src/antibody_training_esm/cli/testing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (249 Bytes). View file
 
src/antibody_training_esm/cli/testing/__pycache__/config.cpython-312.pyc ADDED
Binary file (2.62 kB). View file
 
src/antibody_training_esm/cli/testing/__pycache__/data.cpython-312.pyc ADDED
Binary file (3.47 kB). View file
 
src/antibody_training_esm/cli/testing/__pycache__/evaluation.cpython-312.pyc ADDED
Binary file (5.38 kB). View file
 
src/antibody_training_esm/cli/testing/__pycache__/tester.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
src/antibody_training_esm/cli/testing/__pycache__/visualization.cpython-312.pyc ADDED
Binary file (5.21 kB). View file
 
src/antibody_training_esm/cli/testing/config.py ADDED
@@ -0,0 +1,62 @@
+ """Configuration management for the testing pipeline."""
+
+ from dataclasses import dataclass
+
+ import yaml
+
+ from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE
+
+
+ @dataclass
+ class TestConfig:
+     """Configuration for testing pipeline"""
+
+     model_paths: list[str]
+     data_paths: list[str]
+     sequence_column: str = "sequence"  # Column name for sequences in dataset
+     label_column: str = "label"  # Column name for labels in dataset
+     output_dir: str = "./experiments/benchmarks"
+     metrics: list[str] | None = None
+     save_predictions: bool = True
+     batch_size: int = DEFAULT_BATCH_SIZE  # Batch size for embedding extraction
+     device: str = "mps"  # Device to use for inference [cuda, cpu, mps] - MUST match training config
+     threshold: float | None = (
+         None  # Manual threshold override (None = auto-detect from dataset name)
+     )
+
+     def __post_init__(self) -> None:
+         if self.metrics is None:
+             self.metrics = [
+                 "accuracy",
+                 "precision",
+                 "recall",
+                 "f1",
+                 "roc_auc",
+                 "pr_auc",
+             ]
+
+
+ def load_config_file(config_path: str) -> TestConfig:
+     """Load test configuration from YAML file"""
+     with open(config_path) as f:
+         config_dict = yaml.safe_load(f)
+
+     return TestConfig(**config_dict)
+
+
+ def create_sample_test_config() -> None:
+     """Create a sample test configuration file"""
+     sample_config = {
+         "model_paths": ["./experiments/checkpoints/antibody_classifier.pkl"],
+         "data_paths": ["./sample_data.csv"],
+         "sequence_column": "sequence",
+         "label_column": "label",
+         "output_dir": "./experiments/benchmarks",
+         "metrics": ["accuracy", "precision", "recall", "f1", "roc_auc", "pr_auc"],
+         "save_predictions": True,
+     }
+
+     with open("test_config.yaml", "w") as f:
+         yaml.dump(sample_config, f, default_flow_style=False)
+
+     print("Sample test configuration created: test_config.yaml")
src/antibody_training_esm/cli/testing/data.py ADDED
@@ -0,0 +1,73 @@
+ """Dataset loading and validation utilities."""
+
+ import logging
+ import os
+
+ import pandas as pd
+
+ from antibody_training_esm.cli.testing.config import TestConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_dataset(data_path: str, config: TestConfig) -> tuple[list[str], list[int]]:
+     """
+     Load dataset from CSV file using configured column names.
+
+     Args:
+         data_path: Path to the CSV file.
+         config: Test configuration object containing column names.
+
+     Returns:
+         Tuple of (sequences, labels).
+     """
+     logger.info(f"Loading dataset from {data_path}")
+
+     if not os.path.exists(data_path):
+         raise FileNotFoundError(f"Dataset file not found: {data_path}")
+
+     # Defensive: handle legacy files with comment headers.
+     # New files (post-HF cleanup) are standard CSVs without comments.
+     df = pd.read_csv(data_path, comment="#")
+
+     sequence_col = config.sequence_column
+     label_col = config.label_column
+
+     if sequence_col not in df.columns:
+         raise ValueError(
+             f"Sequence column '{sequence_col}' not found in dataset. Available columns: {list(df.columns)}"
+         )
+     if label_col not in df.columns:
+         raise ValueError(
+             f"Label column '{label_col}' not found in dataset. Available columns: {list(df.columns)}"
+         )
+
+     # CRITICAL VALIDATION: check for NaN labels (P0 bug fix)
+     nan_count = df[label_col].isna().sum()
+     if nan_count > 0:
+         raise ValueError(
+             f"CRITICAL: Dataset contains {nan_count} NaN labels! "
+             f"This will corrupt evaluation metrics. "
+             f"Please use the curated canonical test file (e.g., "
+             f"data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv with no NaNs)."
+         )
+
+     # For Jain test sets, validate expected size (allow legacy 94 + canonical 86)
+     if "jain" in data_path.lower() and "test" in data_path.lower():
+         expected_sizes = {94, 86}
+         if len(df) not in expected_sizes:
+             raise ValueError(
+                 f"Jain test set has {len(df)} antibodies but expected one of {sorted(expected_sizes)}. "
+                 f"Using the wrong test set will produce invalid metrics. "
+                 f"Please use the correct curated file (preferred: "
+                 f"data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv)."
+             )
+
+     sequences = df[sequence_col].tolist()
+     labels = df[label_col].tolist()
+
+     logger.info(
+         f"Loaded {len(sequences)} samples from {data_path} (sequence_col='{sequence_col}', label_col='{label_col}')"
+     )
+     logger.info(f"  Label distribution: {pd.Series(labels).value_counts().to_dict()}")
+     return sequences, labels
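The NaN-label guard above is the key invariant: missing labels must fail loudly before they can skew metrics. A stdlib-only approximation of that check (hypothetical `validate_labels`, with `csv` standing in for pandas and empty fields standing in for NaN):

```python
import csv
import io


def validate_labels(csv_text: str, label_col: str = "label") -> list:
    """Reject datasets with missing labels before they can corrupt evaluation."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    missing = sum(1 for r in rows if not (r.get(label_col) or "").strip())
    if missing:
        raise ValueError(f"Dataset contains {missing} missing labels")
    return [int(r[label_col]) for r in rows]


print(validate_labels("sequence,label\nEVQLV,1\nQVQLQ,0\n"))  # → [1, 0]
```

As in `load_dataset`, the point is to raise rather than silently drop or coerce rows, so a corrupt test file can never produce a plausible-looking metric.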
src/antibody_training_esm/cli/testing/evaluation.py ADDED
@@ -0,0 +1,134 @@
+ """Metric calculation and model evaluation utilities."""
+
+ import logging
+ from typing import Any
+
+ import numpy as np
+ from sklearn.metrics import (
+     classification_report,
+     confusion_matrix,
+ )
+
+ from antibody_training_esm.core.classifier import BinaryClassifier
+ from antibody_training_esm.models.artifact import EvaluationMetrics
+
+ logger = logging.getLogger(__name__)
+
+
+ def detect_assay_type(dataset_name: str) -> str | None:
+     """
+     Auto-detect assay type from dataset name for threshold selection.
+
+     Args:
+         dataset_name: Name of the dataset (e.g., "VH_only_jain", "VHH_only_harvey")
+
+     Returns:
+         'ELISA' for ELISA-based datasets (Boughter, Jain),
+         'PSR' for PSR-based datasets (Harvey, Shehata),
+         None if unable to detect.
+
+     Notes:
+         Novo Nordisk (Sakhnini et al. 2025, Section 2.7):
+         "Antibodies characterised by the PSR assay appear to be on a different
+         non-specificity spectrum than that from the non-specificity ELISA assay."
+
+         PSR datasets require threshold=0.5495 for optimal performance.
+         ELISA datasets use standard threshold=0.5.
+     """
+     dataset_lower = dataset_name.lower()
+
+     # PSR-based datasets (Harvey, Shehata)
+     if any(marker in dataset_lower for marker in ["harvey", "shehata"]):
+         return "PSR"
+
+     # ELISA-based datasets (Boughter, Jain)
+     if any(marker in dataset_lower for marker in ["boughter", "jain"]):
+         return "ELISA"
+
+     # Unable to detect - will use default threshold
+     return None
+
+
+ def evaluate_pretrained(
+     model: BinaryClassifier,
+     X: np.ndarray,
+     y: np.ndarray,
+     model_name: str,
+     dataset_name: str,
+     _metrics_list: list[str] | None = None,
+     threshold_override: float | None = None,
+ ) -> dict[str, Any]:
+     """
+     Evaluate a pretrained model directly on the test set (no retraining).
+
+     Args:
+         model: The trained BinaryClassifier.
+         X: Embeddings (features).
+         y: True labels.
+         model_name: Name of the model for logging.
+         dataset_name: Name of the dataset for logging.
+         _metrics_list: List of metrics to calculate (default: all).
+         threshold_override: Optional manual threshold.
+
+     Returns:
+         Dictionary of results including scores, predictions, and reports.
+         Contains a 'metrics' key with an EvaluationMetrics object.
+     """
+     logger.info(f"Evaluating pretrained model {model_name} on {dataset_name}")
+
+     # Determine threshold: manual override > auto-detect > default 0.5
+     if threshold_override is not None:
+         # Manual override via CLI
+         threshold = threshold_override
+         logger.info(f"Using manual threshold override: {threshold}")
+     else:
+         # Auto-detect assay type from dataset name
+         assay_type = detect_assay_type(dataset_name)
+         if assay_type is not None:
+             threshold = model.ASSAY_THRESHOLDS[assay_type]
+             logger.info(
+                 f"Auto-detected assay type: {assay_type} → threshold={threshold} "
+                 f"(Dataset: {dataset_name})"
+             )
+         else:
+             threshold = 0.5
+             logger.warning(
+                 f"Unable to auto-detect assay type for '{dataset_name}'. "
+                 f"Using default threshold={threshold}. "
+                 f"For optimal results, specify --threshold or use standard dataset names."
+             )
+
+     # Get predictions using the pretrained model with the appropriate threshold
+     y_pred = model.predict(
+         X, threshold=threshold, assay_type=None
+     )  # threshold already determined
+     y_proba = model.predict_proba(X)[:, 1]
+
+     # Create Pydantic metrics
+     eval_metrics = EvaluationMetrics.from_sklearn_metrics(
+         y,
+         y_pred,
+         y_proba.reshape(-1, 1) if y_proba.ndim == 1 else y_proba,
+         dataset_name=dataset_name,
+     )
+
+     # Calculate legacy results for compatibility with visualization tools
+     results = {
+         "metrics": eval_metrics,  # Store Pydantic model
+         "test_scores": eval_metrics.model_dump(
+             exclude={"confusion_matrix", "dataset_name", "n_samples"}
+         ),
+         "predictions": {"y_true": y, "y_pred": y_pred, "y_proba": y_proba},
+         "confusion_matrix": confusion_matrix(y, y_pred),
+         "classification_report": classification_report(y, y_pred, output_dict=True),
+     }
+
+     # Log results
+     logger.info(f"Test results for {model_name} on {dataset_name}:")
+     logger.info(f"  Accuracy: {eval_metrics.accuracy:.4f}")
+     if eval_metrics.f1 is not None:
+         logger.info(f"  F1: {eval_metrics.f1:.4f}")
+     if eval_metrics.roc_auc is not None:
+         logger.info(f"  ROC-AUC: {eval_metrics.roc_auc:.4f}")
+
+     return results
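The threshold-selection cascade in `evaluate_pretrained` (manual override → auto-detected assay threshold → default 0.5) is pure string logic and can be reproduced standalone. The `pick_threshold` helper is hypothetical; the 0.5/0.5495 values mirror the ELISA/PSR thresholds stated in the code and CLI help above:

```python
ASSAY_THRESHOLDS = {"ELISA": 0.5, "PSR": 0.5495}


def detect_assay_type(dataset_name: str):
    name = dataset_name.lower()
    if any(marker in name for marker in ("harvey", "shehata")):
        return "PSR"  # PSR-based datasets
    if any(marker in name for marker in ("boughter", "jain")):
        return "ELISA"  # ELISA-based datasets
    return None  # unknown: caller falls back to the default


def pick_threshold(dataset_name: str, override=None) -> float:
    """Manual override > auto-detected assay threshold > default 0.5."""
    if override is not None:
        return override
    assay = detect_assay_type(dataset_name)
    return ASSAY_THRESHOLDS[assay] if assay is not None else 0.5


print(pick_threshold("VHH_only_harvey"))  # → 0.5495
print(pick_threshold("VH_only_jain_86"))  # → 0.5
print(pick_threshold("mystery_set", override=0.6))  # → 0.6
```

Because detection is substring-based, any file name containing "harvey", "shehata", "boughter", or "jain" is classified; everything else triggers the warning-and-default path.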
src/antibody_training_esm/cli/testing/tester.py ADDED
@@ -0,0 +1,384 @@
1
+ "Model orchestration logic."
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import pickle # nosec B403
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ from antibody_training_esm.cli.testing.config import TestConfig
15
+ from antibody_training_esm.cli.testing.data import load_dataset
16
+ from antibody_training_esm.cli.testing.evaluation import evaluate_pretrained
17
+ from antibody_training_esm.cli.testing.visualization import (
18
+ plot_confusion_matrix,
19
+ save_detailed_results,
20
+ )
21
+ from antibody_training_esm.core.classifier import BinaryClassifier
22
+ from antibody_training_esm.core.config import DEFAULT_BATCH_SIZE
23
+ from antibody_training_esm.core.directory_utils import (
24
+ extract_classifier_shortname,
25
+ extract_model_shortname,
26
+ get_hierarchical_test_results_dir,
27
+ )
28
+ from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor
29
+
30
+
31
+ class ModelTester:
32
+ """Model testing orchestrator"""
33
+
34
+ def __init__(self, config: TestConfig):
35
+ self.config = config
36
+ self.logger = self._setup_logging()
37
+ self.results: dict[str, Any] = {}
38
+ self.cached_embedding_files: list[str] = [] # Track cached files for cleanup
39
+
40
+ # Create output directory
41
+ os.makedirs(config.output_dir, exist_ok=True)
42
+
43
+ def _setup_logging(self) -> logging.Logger:
44
+ """Setup logging configuration"""
45
+ # Create output directory if it doesn't exist
46
+ os.makedirs(self.config.output_dir, exist_ok=True)
47
+
48
+ log_file = os.path.join(
49
+ self.config.output_dir,
50
+ f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log",
51
+ )
52
+
53
+ logging.basicConfig(
54
+ level=logging.INFO,
55
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
56
+ handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
57
+ )
58
+
59
+ return logging.getLogger(__name__)
60
+
61
+ def load_model(self, model_path: str) -> BinaryClassifier:
62
+ """Load trained model from pickle file"""
63
+ self.logger.info(f"Loading model from {model_path}")
64
+
65
+ if not os.path.exists(model_path):
66
+ raise FileNotFoundError(f"Model file not found: {model_path}")
67
+
68
+ with open(model_path, "rb") as f:
69
+ model = pickle.load(f) # nosec B301
70
+
71
+ if not isinstance(model, BinaryClassifier):
72
+ raise ValueError(f"Expected BinaryClassifier, got {type(model)}")
73
+
74
+ # Update device if different from config
75
+ if (
76
+ hasattr(model, "embedding_extractor")
77
+ and model.embedding_extractor.device != self.config.device
78
+ ):
79
+ self.logger.warning(
80
+ f"Device mismatch: model trained on {model.embedding_extractor.device}, "
81
+ f"test config specifies {self.config.device}. Recreating extractor..."
82
+ )
83
+
84
+ # CRITICAL: Explicit cleanup to prevent semaphore leaks (P0 bug fix)
85
+ old_device = str(model.embedding_extractor.device)
86
+ old_extractor = model.embedding_extractor
87
+
88
+ # Delete old extractor before creating new one
89
+ del model.embedding_extractor
90
+ del old_extractor
91
+
92
+ # Clear device-specific GPU cache
93
+ if old_device.startswith("cuda"):
94
+ torch.cuda.empty_cache()
95
+ elif old_device.startswith("mps"):
96
+ torch.mps.empty_cache()
97
+
98
+ self.logger.info(f"Cleaned up old extractor on {old_device}")
99
+
100
+ # NOW create new extractor (no leak)
101
+ batch_size = getattr(model, "batch_size", DEFAULT_BATCH_SIZE)
102
+ revision = getattr(model, "revision", "main")
103
+ model.embedding_extractor = ESMEmbeddingExtractor(
104
+ model.model_name, self.config.device, batch_size, revision=revision
105
+ )
106
+ model.device = self.config.device
107
+
108
+ self.logger.info(f"Created new extractor on {self.config.device}")
109
+
110
+ # Update batch_size if different from config
111
+ if (
112
+ hasattr(model, "embedding_extractor")
113
+ and model.embedding_extractor.batch_size != self.config.batch_size
114
+ ):
115
+ self.logger.info(
116
+ f"Updating batch_size from {model.embedding_extractor.batch_size} to {self.config.batch_size}"
117
+ )
118
+ model.embedding_extractor.batch_size = self.config.batch_size
119
+
120
+ self.logger.info(
121
+ f"Model loaded successfully: {model_path} on device: {model.embedding_extractor.device}"
122
+ )
123
+ return model
124
+
125
+ def embed_sequences(
126
+ self,
127
+ sequences: list[str],
128
+ model: BinaryClassifier,
129
+ dataset_name: str,
130
+ output_dir: str,
131
+ ) -> np.ndarray:
132
+ """Extract embeddings for sequences using the model's embedding extractor"""
133
+ # Ensure output directory exists before file I/O
134
+ os.makedirs(output_dir, exist_ok=True)
135
+
136
+ cache_file = os.path.join(output_dir, f"{dataset_name}_test_embeddings.pkl")
137
+
138
+ # Track this file for cleanup
139
+ if cache_file not in self.cached_embedding_files:
140
+ self.cached_embedding_files.append(cache_file)
141
+
142
+ # Try to load from cache
143
+ if os.path.exists(cache_file):
144
+ try:
145
+ self.logger.info(f"Loading cached embeddings from {cache_file}")
146
+ with open(cache_file, "rb") as f:
147
+ embeddings: np.ndarray = pickle.load(f) # nosec B301
148
+
149
+ # Validate shape and type
150
+ if not isinstance(embeddings, np.ndarray):
151
+ raise ValueError(f"Invalid cache data type: {type(embeddings)}")
152
+ if embeddings.ndim != 2:
153
+ raise ValueError(f"Invalid embedding shape: {embeddings.shape}")
154
+
155
+ if len(embeddings) == len(sequences):
156
+ self.logger.info(f"Loaded {len(embeddings)} cached embeddings")
157
+ return embeddings
158
+ else:
159
+ self.logger.warning(
160
+ "Cached embeddings size mismatch, recomputing..."
161
+ )
162
+
163
+ except (pickle.UnpicklingError, EOFError, ValueError, AttributeError) as e:
164
+ self.logger.warning(
165
+ f"Failed to load cached embeddings from {cache_file}: {e}. "
166
+ "Recomputing embeddings..."
167
+ )
168
+ # Fall through to recomputation below
169
+
170
+ # Extract embeddings
171
+ self.logger.info(f"Extracting embeddings for {len(sequences)} sequences...")
172
+ embeddings = model.embedding_extractor.extract_batch_embeddings(sequences)
173
+
174
+ # Cache embeddings
175
+ with open(cache_file, "wb") as f:
176
+ pickle.dump(embeddings, f)
177
+ self.logger.info(f"Embeddings cached to {cache_file}")
178
+
179
+ return embeddings
180
+
181
+ def cleanup_cached_embeddings(self) -> None:
182
+ """Delete cached embedding files"""
183
+ self.logger.info("Cleaning up cached embedding files...")
184
+ for cache_file in self.cached_embedding_files:
185
+ if os.path.exists(cache_file):
186
+ try:
187
+ os.remove(cache_file)
188
+ self.logger.info(f"Deleted cached embeddings: {cache_file}")
189
+ except Exception as e:
190
+ self.logger.warning(f"Failed to delete {cache_file}: {e}")
191
+
192
+ def _compute_output_directory(
193
+ self,
194
+ model_path: str | None,
195
+ dataset_name: str,
196
+ ) -> str:
197
+ """Compute output directory (hierarchical if model config available, else flat)."""
198
+ if model_path is None:
199
+ self.logger.warning("No model path provided, using flat output structure")
200
+ return self.config.output_dir
201
+
202
+ # Try to load model config JSON
203
+ model_config_path = (
204
+ Path(model_path)
205
+ .with_suffix("")
206
+ .with_name(Path(model_path).stem + "_config.json")
207
+ )
208
+
209
+ if not model_config_path.exists():
210
+ self.logger.info(
211
+ f"Model config not found at {model_config_path}, using flat output structure"
212
+ )
213
+ return self.config.output_dir
214
+
215
+ try:
216
+ with open(model_config_path) as f:
217
+ model_config = json.load(f)
218
+
219
+ model_name = model_config.get("model_name") or model_config.get(
220
+ "esm_model", ""
221
+ )
222
+ if not model_name:
223
+ raise ValueError("Model config missing 'model_name' or 'esm_model'")
224
+
225
+ classifier_config = model_config.get("classifier", {})
226
+
227
+ # Use shared utility for hierarchical path generation
228
+ hierarchical_path = get_hierarchical_test_results_dir(
229
+ base_dir=self.config.output_dir,
230
+ model_name=model_name,
231
+ classifier_config=classifier_config,
232
+ dataset_name=dataset_name,
233
+ )
234
+
235
+ # Extract shortnames for logging
236
+ model_short = extract_model_shortname(model_name)
237
+ classifier_short = extract_classifier_shortname(classifier_config)
238
+
239
+ self.logger.info(
240
+ f"Using hierarchical output: {hierarchical_path} "
241
+ f"(model={model_short}, classifier={classifier_short})"
242
+ )
243
+ return str(hierarchical_path)
244
+
245
+ except (json.JSONDecodeError, KeyError, ValueError) as e:
246
+ self.logger.warning(
247
+ f"Could not determine hierarchical path from model config: {e}. "
248
+ "Using flat structure."
249
+ )
250
+ return self.config.output_dir
251
+
252
+ def run_comprehensive_test(self) -> dict[str, dict[str, Any]]:
253
+ """Run testing pipeline"""
254
+ self.logger.info("Starting model testing")
255
+ self.logger.info(f"Models to test: {self.config.model_paths}")
256
+ self.logger.info(f"Datasets to test: {self.config.data_paths}")
257
+
258
+ all_results = {}
259
+ failed_datasets = []
260
+ failed_models = []
261
+
262
+ try:
263
+ # Test each dataset
264
+ for data_path in self.config.data_paths:
265
+ dataset_name = Path(data_path).stem
266
+ self.logger.info(f"\n{'=' * 60}")
267
+ self.logger.info(f"Testing on dataset: {dataset_name}")
268
+ self.logger.info(f"{'=' * 60}")
269
+
270
+ # Load dataset
271
+ try:
272
+ sequences, labels_list = load_dataset(data_path, self.config)
273
+ labels: np.ndarray = np.array(labels_list)
274
+ except Exception as e:
275
+ self.logger.error(f"Failed to load dataset {data_path}: {e}")
276
+                     failed_datasets.append((dataset_name, str(e)))
+                     continue
+
+                 dataset_results = {}
+
+                 # Test each model
+                 for model_path in self.config.model_paths:
+                     model_name = Path(model_path).stem
+                     self.logger.info(f"\nTesting model: {model_name}")
+
+                     output_dir_for_dataset = self._compute_output_directory(
+                         model_path, dataset_name
+                     )
+
+                     try:
+                         # Load model
+                         model = self.load_model(model_path)
+
+                         # Extract embeddings
+                         X_embedded = self.embed_sequences(
+                             sequences,
+                             model,
+                             f"{dataset_name}_{model_name}",
+                             output_dir_for_dataset,
+                         )
+
+                         # Evaluation (delegated to evaluation module)
+                         test_results = evaluate_pretrained(
+                             model,
+                             X_embedded,
+                             labels,
+                             model_name,
+                             dataset_name,
+                             self.config.metrics,
+                             self.config.threshold,
+                         )
+                         dataset_results[model_name] = test_results
+
+                         # Visualization (delegated to visualization module)
+                         single_model_results = {model_name: test_results}
+                         plot_confusion_matrix(
+                             single_model_results,
+                             dataset_name,
+                             output_dir=output_dir_for_dataset,
+                         )
+                         save_detailed_results(
+                             single_model_results,
+                             dataset_name,
+                             self.config.__dict__,
+                             output_dir=output_dir_for_dataset,
+                             save_predictions=self.config.save_predictions,
+                         )
+
+                     except Exception as e:
+                         self.logger.error(f"Failed to test model {model_path}: {e}")
+                         failed_models.append((f"{dataset_name}_{model_name}", str(e)))
+                         continue
+
+                 # Generate aggregated multi-model report
+                 if dataset_results:
+                     aggregated_output_dir = self.config.output_dir
+                     self.logger.info(
+                         f"Generating aggregated multi-model report for {dataset_name} "
+                         f"in {aggregated_output_dir}"
+                     )
+
+                     plot_confusion_matrix(
+                         dataset_results,
+                         dataset_name,
+                         output_dir=aggregated_output_dir,
+                     )
+                     save_detailed_results(
+                         dataset_results,
+                         dataset_name,
+                         self.config.__dict__,
+                         output_dir=aggregated_output_dir,
+                         save_predictions=self.config.save_predictions,
+                     )
+
+                 all_results[dataset_name] = dataset_results
+
+             # Check if all tests failed
+             if not all_results:
+                 error_msg = "All tests failed:\n"
+                 if failed_datasets:
+                     error_msg += (
+                         f"  Failed datasets: {[name for name, _ in failed_datasets]}\n"
+                     )
+                 if failed_models:
+                     error_msg += (
+                         f"  Failed models: {[name for name, _ in failed_models]}\n"
+                     )
+                 raise RuntimeError(error_msg + "No successful test results to report.")
+
+             if failed_datasets or failed_models:
+                 self.logger.warning(
+                     f"\nSome tests failed (datasets: {len(failed_datasets)}, "
+                     f"models: {len(failed_models)}). Check logs for details."
+                 )
+
+             self.results = all_results
+             self.logger.info(
+                 f"\nTesting completed. Results saved to: {self.config.output_dir}"
+             )
+
+         finally:
+             self.cleanup_cached_embeddings()
+
+         return all_results
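The tester above deliberately collects per-dataset and per-model failures and raises only when nothing succeeded at all. That pattern can be sketched in isolation; `run_all`, `run_one`, and the item names here are illustrative stand-ins, not the project's actual API:

```python
def run_all(items, run_one):
    """Run run_one on each item; collect failures instead of aborting early."""
    results, failures = {}, []
    for name in items:
        try:
            results[name] = run_one(name)
        except Exception as e:  # broad catch is deliberate: isolate per-item errors
            failures.append((name, str(e)))
            continue
    if not results:
        # Every item failed: surface all collected errors at once
        raise RuntimeError(f"All runs failed: {[n for n, _ in failures]}")
    return results, failures


def run_one(name):
    """Toy stand-in for loading + evaluating one dataset."""
    if name == "jain":
        return {"accuracy": 0.8}
    raise ValueError("bad csv")


ok, failed = run_all(["jain", "boughter"], run_one)
print(ok)      # {'jain': {'accuracy': 0.8}}
print(failed)  # [('boughter', 'bad csv')]
```

The design choice is that one bad dataset or checkpoint should not discard results from the others; only a total failure is fatal.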
src/antibody_training_esm/cli/testing/visualization.py ADDED
@@ -0,0 +1,127 @@
+ """Plotting and result serialization utilities."""
+
+ import logging
+ import os
+ from datetime import datetime
+ from typing import Any
+
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import seaborn as sns
+ import yaml
+
+ # Configure matplotlib
+ plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "default")
+ sns.set_palette("husl")
+
+ logger = logging.getLogger(__name__)
+
+
+ def plot_confusion_matrix(
+     results: dict[str, dict[str, Any]],
+     dataset_name: str,
+     output_dir: str,
+ ) -> None:
+     """
+     Create confusion matrix visualization (individual files per model).
+
+     Args:
+         results: Dictionary mapping model names to result dictionaries.
+         dataset_name: Name of the dataset.
+         output_dir: Directory to save plots.
+     """
+     os.makedirs(output_dir, exist_ok=True)
+
+     logger.info(f"Creating confusion matrices for {dataset_name} in {output_dir}")
+
+     # Create an individual confusion matrix for each model to prevent overwrites
+     for model_name, model_results in results.items():
+         if "confusion_matrix" not in model_results:
+             logger.warning(f"No confusion matrix found for {model_name}, skipping plot")
+             continue
+
+         fig, ax = plt.subplots(1, 1, figsize=(8, 6))
+         cm = model_results["confusion_matrix"]
+         sns.heatmap(
+             cm,
+             annot=True,
+             fmt="d",
+             cmap="Blues",
+             xticklabels=["Negative", "Positive"],
+             yticklabels=["Negative", "Positive"],
+             ax=ax,
+         )
+         ax.set_title(f"Confusion Matrix - {model_name} on {dataset_name}")
+         ax.set_ylabel("True Label")
+         ax.set_xlabel("Predicted Label")
+
+         plt.tight_layout()
+
+         # Save plot with model name to prevent overwrites when testing multiple backbones
+         plot_file = os.path.join(
+             output_dir,
+             f"confusion_matrix_{model_name}_{dataset_name}.png",
+         )
+         plt.savefig(plot_file, dpi=300, bbox_inches="tight")
+         plt.close()
+
+         logger.info(f"Confusion matrix saved to {plot_file}")
+
+
+ def save_detailed_results(
+     results: dict[str, dict[str, Any]],
+     dataset_name: str,
+     config_dict: dict[str, Any],
+     output_dir: str,
+     save_predictions: bool = True,
+ ) -> None:
+     """
+     Save detailed results to files (individual files per model).
+
+     Args:
+         results: Dictionary mapping model names to result dictionaries.
+         dataset_name: Name of the dataset.
+         config_dict: Configuration dictionary to embed in YAML.
+         output_dir: Directory to save results.
+         save_predictions: Whether to save prediction CSVs.
+     """
+     os.makedirs(output_dir, exist_ok=True)
+
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     # Save an individual YAML for each model to prevent overwrites
+     for model_name, model_results in results.items():
+         results_file = os.path.join(
+             output_dir,
+             f"detailed_results_{model_name}_{dataset_name}_{timestamp}.yaml",
+         )
+         with open(results_file, "w") as f:
+             yaml.dump(
+                 {
+                     "dataset": dataset_name,
+                     "model": model_name,
+                     "config": config_dict,
+                     "results": model_results,
+                 },
+                 f,
+                 default_flow_style=False,
+             )
+         logger.info(f"Detailed results saved to {results_file}")
+
+     # Save predictions if requested
+     if save_predictions:
+         for model_name, model_results in results.items():
+             if "predictions" in model_results:
+                 pred_file = os.path.join(
+                     output_dir,
+                     f"predictions_{model_name}_{dataset_name}_{timestamp}.csv",
+                 )
+                 pred_df = pd.DataFrame(
+                     {
+                         "y_true": model_results["predictions"]["y_true"],
+                         "y_pred": model_results["predictions"]["y_pred"],
+                         "y_proba": model_results["predictions"]["y_proba"],
+                     }
+                 )
+                 pred_df.to_csv(pred_file, index=False)
+                 logger.info(f"Predictions saved to {pred_file}")
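The module's key convention is one output file per (model, dataset) pair, with the model name baked into every filename so repeated runs against multiple backbones never overwrite each other. A stdlib-only sketch of that naming scheme (directory and names are illustrative):

```python
import os
from datetime import datetime


def result_paths(output_dir, model_name, dataset_name, timestamp=None):
    """Mirror the per-model naming: plot, YAML, and predictions CSV per pair."""
    ts = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    return {
        "plot": os.path.join(
            output_dir, f"confusion_matrix_{model_name}_{dataset_name}.png"
        ),
        "yaml": os.path.join(
            output_dir, f"detailed_results_{model_name}_{dataset_name}_{ts}.yaml"
        ),
        "csv": os.path.join(
            output_dir, f"predictions_{model_name}_{dataset_name}_{ts}.csv"
        ),
    }


paths = result_paths("experiments/benchmarks", "esm1v_logreg", "jain", "20250101_120000")
print(paths["plot"])  # experiments/benchmarks/confusion_matrix_esm1v_logreg_jain.png
```

Note the plot name omits the timestamp (latest plot per pair is kept), while YAML and CSV are timestamped so successive runs accumulate.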
src/antibody_training_esm/cli/train.py ADDED
@@ -0,0 +1,42 @@
+ """
+ Training CLI - Hydra Entry Point
+
+ Professional command-line interface for antibody model training.
+ Uses Hydra for configuration management and supports dynamic overrides.
+
+ Usage:
+     # Default config
+     antibody-train
+
+     # With overrides
+     antibody-train hardware.device=cuda training.batch_size=16
+
+     # Multi-run sweep
+     antibody-train --multirun classifier.C=0.1,1.0,10.0
+
+     # Help
+     antibody-train --help
+ """
+
+ from antibody_training_esm.core.trainer import main as hydra_main
+
+
+ def main() -> None:
+     """
+     Main entry point for the training CLI.
+
+     Delegates to the Hydra-decorated main() in core.trainer.
+     This provides automatic config composition, override support,
+     and multi-run sweeps.
+
+     Note:
+         This function does not return an exit code (Hydra handles that).
+         Use try/except at a higher level if you need custom error handling.
+     """
+     # Delegate to the Hydra entry point.
+     # Hydra automatically parses sys.argv and handles all CLI logic.
+     hydra_main()
+
+
+ if __name__ == "__main__":
+     main()
src/antibody_training_esm/conf/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """
+ Hydra configuration package
+
+ Contains YAML configs and structured config schemas.
+ """
+
+ # Import config_schema for its schema dataclasses. Note: ConfigStore
+ # registrations are intentionally disabled there (see the comment in
+ # config_schema.py); the YAML files are the single source of truth.
+ from . import config_schema  # noqa: F401
src/antibody_training_esm/conf/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (356 Bytes).
 
src/antibody_training_esm/conf/__pycache__/config_schema.cpython-312.pyc ADDED
Binary file (5.3 kB).
 
src/antibody_training_esm/conf/classifier/logreg.yaml ADDED
@@ -0,0 +1,12 @@
+ type: logistic_regression
+ C: 1.0
+ penalty: l2
+ solver: lbfgs
+ max_iter: 1000
+ random_state: ${training.random_state}
+ class_weight: null
+ cv_folds: 10
+ stratify: true
+ path: null
+ # Optional path to the JSON config file (for .npz models)
+ config_path: null
src/antibody_training_esm/conf/classifier/xgboost.yaml ADDED
@@ -0,0 +1,14 @@
+ type: xgboost
+ n_estimators: 100
+ max_depth: 6
+ learning_rate: 0.3
+ subsample: 1.0
+ colsample_bytree: 1.0
+ reg_alpha: 0.0
+ reg_lambda: 1.0
+ random_state: ${training.random_state}
+ objective: binary:logistic
+ cv_folds: 10
+ stratify: true
+ path: null
+ config_path: null
src/antibody_training_esm/conf/config.yaml ADDED
@@ -0,0 +1,36 @@
+ defaults:
+   - model: esm1v
+   - classifier: logreg
+   - data: boughter_jain
+   - hardware: default
+   - hydra: default
+   - _self_
+
+ # Training settings (matches current trainer.py requirements)
+ training:
+   # Cross-validation
+   n_splits: 10
+   random_state: 42
+   stratify: true
+
+   # Evaluation metrics (list of metrics to compute)
+   metrics: [accuracy, precision, recall, f1, roc_auc]
+
+   # Model saving
+   save_model: true
+   model_name: boughter_vh_esm1v_logreg
+   model_save_dir: ./experiments/checkpoints
+
+   # Logging (Hydra-aware: relative to Hydra output dir, or logs/ in legacy mode)
+   log_level: INFO
+   log_file: training.log
+
+   # Performance optimization
+   batch_size: 8
+   num_workers: 4
+
+ # Experiment metadata (Hydra manages output dirs)
+ experiment:
+   name: novo_replication
+   description: "Train ESM-1v VH-based LogisticReg on Boughter, test on Jain"
+   tags: [baseline, esm1v, logreg]
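The `defaults` list above tells Hydra to merge one file from each config group (model, classifier, data, hardware) into a single config, with command-line overrides applied last. The composition idea can be approximated with plain dict merging; `GROUPS` and `compose` here are a toy stand-in for Hydra's own machinery, not its API:

```python
# Toy stand-in for the conf/ tree: (group, choice) -> config fragment
GROUPS = {
    ("classifier", "logreg"): {"classifier": {"type": "logistic_regression", "C": 1.0}},
    ("hardware", "default"): {"hardware": {"device": "mps"}},
}


def compose(defaults, overrides=None):
    """Merge the selected group configs, then apply dotted-key overrides."""
    cfg = {}
    for group, choice in defaults:
        cfg.update(GROUPS[(group, choice)])
    for key, value in (overrides or {}).items():
        section, leaf = key.split(".")
        cfg[section][leaf] = value  # overrides win, as on the Hydra CLI
    return cfg


cfg = compose(
    [("classifier", "logreg"), ("hardware", "default")],
    overrides={"hardware.device": "cpu"},
)
print(cfg["hardware"]["device"])  # cpu
```

This mirrors why `antibody-train hardware.device=cuda` works without editing any YAML: the override layer sits on top of the composed defaults.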
src/antibody_training_esm/conf/config_schema.py ADDED
@@ -0,0 +1,142 @@
+ """
+ Structured configuration schemas for Hydra
+
+ Type-safe configuration using dataclasses with full field coverage
+ validated against current trainer.py requirements.
+ """
+
+ from dataclasses import dataclass, field
+
+ # ConfigStore import removed - no longer needed since registrations are commented out
+ # from hydra.core.config_store import ConfigStore
+ from omegaconf import MISSING
+
+
+ @dataclass
+ class ModelConfig:
+     """ESM model configuration (matches current model config structure)"""
+
+     name: str = "facebook/esm1v_t33_650M_UR90S_1"
+     revision: str = "main"
+     device: str = MISSING  # Provided by YAML interpolation ${hardware.device}
+
+
+ @dataclass
+ class ClassifierConfig:
+     """Classifier head configuration (matches current classifier config)"""
+
+     type: str = "logistic_regression"
+     C: float = 1.0
+     penalty: str = "l2"
+     solver: str = "lbfgs"
+     max_iter: int = 1000
+     random_state: int = MISSING  # Provided by YAML interpolation ${training.random_state}
+     class_weight: str | None = None
+     cv_folds: int = 10
+     stratify: bool = True
+
+
+ @dataclass
+ class DataConfig:
+     """Dataset configuration (ALL fields used by loaders.py + trainer.py)"""
+
+     # REQUIRED by loaders.py
+     source: str = "local"
+     train_file: str = MISSING  # Required
+     test_file: str = MISSING  # Required
+     sequence_column: str = "sequence"
+     label_column: str = "label"
+
+     # REQUIRED by trainer.py
+     embeddings_cache_dir: str = "./experiments/cache"
+
+     # Optional fields
+     dataset_name: str = "boughter_vh"
+     max_sequence_length: int = 1024
+     save_embeddings: bool = True
+
+     # Fragment metadata (testing only)
+     train_fragment: str = "VH"
+     test_fragment: str = "VH"
+     test_assay: str = "ELISA"
+     test_threshold: float = 0.5
+
+
+ @dataclass
+ class TrainingConfig:
+     """Training hyperparameters (ALL fields used by trainer.py)"""
+
+     # Cross-validation
+     n_splits: int = 10
+     random_state: int = 42
+     stratify: bool = True
+
+     # Evaluation metrics
+     metrics: list[str] = field(
+         default_factory=lambda: ["accuracy", "precision", "recall", "f1", "roc_auc"]
+     )
+
+     # Model saving
+     save_model: bool = True
+     model_name: str = "boughter_vh_esm1v_logreg"
+     model_save_dir: str = "./experiments/checkpoints"
+
+     # Logging (Hydra-aware: relative to Hydra output dir, or logs/ in legacy mode)
+     log_level: str = "INFO"
+     log_file: str = "training.log"  # Routes to logs/ dir in legacy mode, Hydra output dir in Hydra mode
+
+     # Performance optimization
+     batch_size: int = 8
+     num_workers: int = 4
+
+
+ @dataclass
+ class HardwareConfig:
+     """Hardware settings"""
+
+     device: str = "mps"
+     gpu_memory_fraction: float = 0.8
+     clear_cache_frequency: int = 100
+
+
+ @dataclass
+ class ExperimentConfig:
+     """Experiment metadata"""
+
+     name: str = "novo_replication"
+     description: str = "Train ESM-1v VH-based LogisticReg on Boughter, test on Jain"
+     tags: list[str] = field(default_factory=lambda: ["baseline", "esm1v", "logreg"])
+
+
+ @dataclass
+ class Config:
+     """Root configuration (complete schema matching current trainer.py)"""
+
+     model: ModelConfig = field(default_factory=ModelConfig)
+     classifier: ClassifierConfig = field(default_factory=ClassifierConfig)
+     data: DataConfig = field(default_factory=DataConfig)
+     training: TrainingConfig = field(default_factory=TrainingConfig)
+     hardware: HardwareConfig = field(default_factory=HardwareConfig)
+     experiment: ExperimentConfig = field(default_factory=ExperimentConfig)
+
+
+ # ConfigStore registrations REMOVED to fix CLI override bug
+ #
+ # Root cause: Registering structured configs with the same names as YAML files
+ # causes Hydra to prefer ConfigStore over YAML when using package-based config
+ # loading (which the console script does). This breaks config group overrides.
+ #
+ # Known issue: Hydra structured configs strictly validate keys.
+ # Overrides adding new keys require proper schema definition or +key syntax
+ # with strict mode disabled.
+ # See: https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching
+ #
+ # The dataclasses above are kept for type hints and validation in code, but are
+ # no longer registered with ConfigStore. This allows YAML files to be the single
+ # source of truth for configuration.
+ #
+ # cs = ConfigStore.instance()
+ # cs.store(name="config", node=Config)
+ # cs.store(group="model", name="esm1v", node=ModelConfig)
+ # cs.store(group="classifier", name="logreg", node=ClassifierConfig)
+ # cs.store(group="data", name="boughter_jain", node=DataConfig)
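These schemas lean on `field(default_factory=...)` so that mutable defaults (lists, nested configs) are created per instance rather than shared. That behavior can be checked standalone; this minimal sketch avoids the omegaconf dependency by using a local `"???"` sentinel in place of `omegaconf.MISSING`:

```python
from dataclasses import dataclass, field

MISSING = "???"  # local stand-in for omegaconf.MISSING


@dataclass
class Training:
    random_state: int = 42
    metrics: list[str] = field(default_factory=lambda: ["accuracy", "f1"])


@dataclass
class Root:
    training: Training = field(default_factory=Training)
    device: str = MISSING  # must be supplied (e.g. via interpolation) before use


a, b = Root(), Root()
a.training.metrics.append("roc_auc")
print(b.training.metrics)  # ['accuracy', 'f1'] -- unaffected by the mutation of a
```

A bare `metrics: list[str] = ["accuracy", "f1"]` would be rejected by `@dataclass` precisely to prevent the shared-mutable-default trap this pattern avoids.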
src/antibody_training_esm/conf/data/boughter_jain.yaml ADDED
@@ -0,0 +1,23 @@
+ # Data source (matches current loaders.py requirements)
+ source: local
+ dataset_name: boughter_vh
+
+ # File paths
+ train_file: data/train/boughter/canonical/VH_only_boughter_training.csv
+ test_file: data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv
+
+ # Data format options (required by loaders.py)
+ # Jain canonical parity file uses 'vh_sequence'; align config to avoid column errors
+ sequence_column: sequence
+ label_column: label
+ max_sequence_length: 1024
+
+ # Embedding caching (required by trainer.py)
+ save_embeddings: true
+ embeddings_cache_dir: ./experiments/cache
+
+ # Fragment metadata (for testing only)
+ train_fragment: VH
+ test_fragment: VH
+ test_assay: ELISA
+ test_threshold: 0.5
src/antibody_training_esm/conf/hardware/default.yaml ADDED
@@ -0,0 +1,5 @@
+ # Hardware configuration
+ # Default to MPS for macOS performance (training/testing); Gradio app handles stability fallback
+ device: mps
+ gpu_memory_fraction: 0.8
+ clear_cache_frequency: 100
src/antibody_training_esm/conf/hydra/default.yaml ADDED
@@ -0,0 +1,10 @@
+ # Hydra output directory management
+ run:
+   dir: experiments/runs/${experiment.name}/${now:%Y-%m-%d_%H-%M-%S}
+
+ sweep:
+   dir: experiments/runs/sweeps/${experiment.name}
+   subdir: ${hydra.job.num}
+
+ job:
+   chdir: false  # Don't change working directory
src/antibody_training_esm/conf/model/esm1v.yaml ADDED
@@ -0,0 +1,4 @@
+ name: facebook/esm1v_t33_650M_UR90S_1
+ revision: main
+ # Device is interpolated from the hardware config (default: mps); override via hardware.device or the CLI if desired
+ device: ${hardware.device}
src/antibody_training_esm/conf/model/esm2_650m.yaml ADDED
@@ -0,0 +1,3 @@
+ name: facebook/esm2_t33_650M_UR50D
+ revision: main
+ device: ${hardware.device}
src/antibody_training_esm/conf/predict.yaml ADDED
@@ -0,0 +1,26 @@
+ # @package _global_
+
+ defaults:
+   - /model: esm1v
+   - /classifier: logreg
+   - /hardware: default
+   - _self_
+
+ input_file: null
+ output_file: "predictions.csv"
+ sequence_column: "sequence"
+ assay_type: null  # Options: "PSR", "ELISA", or null
+ threshold: 0.5  # Ignored if assay_type is set
+
+ gradio:
+   server_name: "0.0.0.0"
+   server_port: 7860
+   share: false
+   queue:
+     concurrency_limit: 2  # Based on 8GB VRAM (3GB per ESM-1v inference)
+     max_size: 10  # Prevents unbounded queue growth
+ log_level: INFO
+
+ hydra:
+   job:
+     chdir: False
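The comment that `threshold` is ignored whenever `assay_type` is set implies a small precedence rule when resolving the decision cutoff. A hedged sketch of that rule; the per-assay values below are illustrative placeholders, not the trained model's calibrated thresholds:

```python
# Hypothetical per-assay thresholds, for illustration only
ASSAY_THRESHOLDS = {"PSR": 0.6, "ELISA": 0.5}


def resolve_threshold(assay_type, threshold=0.5):
    """assay_type, when given, takes precedence over the manual threshold."""
    if assay_type is not None:
        return ASSAY_THRESHOLDS[assay_type]
    return threshold


print(resolve_threshold(None, 0.7))   # 0.7
print(resolve_threshold("PSR", 0.7))  # 0.6 (manual threshold ignored)
```

Keeping the assay lookup first means a user-supplied `threshold` can never silently override an assay-specific cutoff.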
src/antibody_training_esm/conf/testing/jain_p5e_s2.yaml ADDED
@@ -0,0 +1,7 @@
+ model_paths: [experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl]
+ data_paths: [data/test/jain/canonical/VH_only_jain_86_p5e_s2.csv]
+ sequence_column: vh_sequence
+ label_column: label
+ output_dir: experiments/benchmarks
+ device: cpu
+ batch_size: 8
src/antibody_training_esm/core/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """
+ Core ML Module
+
+ Professional ML components for antibody classification:
+ - ESM embedding extraction
+ - Binary classification
+ - Training pipelines
+ - Model serialization (pickle + NPZ+JSON)
+ """
+
+ from antibody_training_esm.core.classifier import BinaryClassifier
+ from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor
+ from antibody_training_esm.core.trainer import load_model_from_npz
+
+ __all__ = [
+     "BinaryClassifier",
+     "ESMEmbeddingExtractor",
+     "load_model_from_npz",
+ ]
src/antibody_training_esm/core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (705 Bytes).
 
src/antibody_training_esm/core/__pycache__/classifier.cpython-312.pyc ADDED
Binary file (14.2 kB).
 
src/antibody_training_esm/core/__pycache__/classifier_factory.cpython-312.pyc ADDED
Binary file (4.66 kB).