lighteternal committed
Commit 57f1553 · verified · 1 parent: 6fb0bbf

Upload BioAssayAlign compatibility Space bundle

Files changed (5):
  1. README.md +70 -5
  2. app.py +407 -0
  3. examples/btk_candidates.csv +6 -0
  4. requirements.txt +8 -0
  5. space_runtime.py +554 -0
README.md CHANGED
@@ -1,12 +1,77 @@
  ---
  title: BioAssayAlign Compatibility Explorer
- emoji: 🦀
- colorFrom: gray
- colorTo: yellow
+ emoji: 🧪
+ colorFrom: blue
+ colorTo: gray
  sdk: gradio
- sdk_version: 6.9.0
  app_file: app.py
  pinned: false
+ license: mit
+ short_description: Rank candidate molecules for a bioassay.
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # BioAssayAlign Compatibility Explorer
+
+ This Space is a scientist-facing demo for **assay-conditioned compound ranking**.
+
+ You provide:
+ - a bioassay description and optional metadata
+ - a list of candidate SMILES
+
+ The model returns:
+ - a ranked list of molecules
+ - a compatibility score for each one
+ - explicit flags for invalid SMILES
+
+ ## What It Is
+
+ This is not a chatbot and it is not a potency predictor.
+
+ It is a **ranking model** trained on a frozen public bioassay dataset built from PubChem BioAssay and ChEMBL. It is designed to answer:
+
+ > “Given this assay, which molecules should I look at first?”
+
+ ## What The Score Means
+
+ - Higher score = the model believes the molecule is more compatible with the assay than lower-ranked candidates in the same list.
+ - The score is **not** a probability.
+ - The score is best used for **ranking**, not absolute decision thresholds.
+
+ ## Recommended Input Style
+
+ The model is most reliable when assay information is provided as structured fields:
+ - title
+ - description
+ - organism
+ - readout
+ - assay format
+ - assay type
+ - target UniProt IDs
+
+ You can paste SMILES directly or upload a CSV with a `smiles` or `canonical_smiles` column.
+
+ ## Good Uses
+
+ - ranking a screening shortlist for a new assay concept
+ - triaging compounds before a more expensive downstream model or wet-lab step
+ - testing how sensitive rankings are to assay wording and metadata
+
+ ## Limits
+
+ - This is a public-data model, not a medicinal chemistry oracle.
+ - It does not predict IC50 directly.
+ - It is strongest as a **relative ranking tool** over a candidate list you already care about.
+
+ ## Runtime Notes
+
+ - The first request can be slower because the Space has to load the model.
+ - Large candidate lists increase runtime. For interactive use, start with a few hundred molecules.
+
+ ## Model
+
+ The Space reads the model repo from the `MODEL_REPO_ID` environment variable.
+
+ Default:
+ - `lighteternal/BioAssayAlign-Qwen3-Embedding-0.6B-Compatibility`
+
+ If the champion changes later, the Space can point to a new model repo without changing the UI.
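
The structured fields listed in the README are joined into a single labeled-section prompt by `serialize_assay_query` in `space_runtime.py`. A minimal standalone sketch of that serialization format (the section names mirror `SECTION_ORDER`; the `serialize` function here is an illustration, not the Space's own code):

```python
# Sketch of the assay-text format the model consumes: each field is rendered
# as "[KEY]\n<value>" and sections are joined with blank lines.
SECTION_ORDER = [
    "ASSAY_TITLE",
    "DESCRIPTION",
    "ORGANISM",
    "READOUT",
    "ASSAY_FORMAT",
    "ASSAY_TYPE",
    "TARGET_UNIPROT",
]

def serialize(values: dict[str, str]) -> str:
    # Missing fields become empty sections, matching the Space's behavior.
    return "\n\n".join(f"[{key}]\n{values.get(key, '').strip()}" for key in SECTION_ORDER)

text = serialize({
    "ASSAY_TITLE": "BTK kinase inhibitor binding assay",
    "ORGANISM": "Homo sapiens",
    "TARGET_UNIPROT": "Q06187",
})
print(text.splitlines()[0])  # → [ASSAY_TITLE]
```

The "Serialized assay text used by the model" accordion in the UI shows exactly this string, so you can check what the model actually saw.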
app.py ADDED
@@ -0,0 +1,407 @@
+ from __future__ import annotations
+
+ import csv
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import Any
+
+ import gradio as gr
+ import pandas as pd
+
+ from space_runtime import AssayQuery, load_compatibility_model_from_hub, rank_compounds, serialize_assay_query
+
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "lighteternal/BioAssayAlign-Qwen3-Embedding-0.6B-Compatibility")
+ MAX_INPUT_SMILES = int(os.getenv("MAX_INPUT_SMILES", "3000"))
+ DEFAULT_TOP_K = int(os.getenv("DEFAULT_TOP_K", "50"))
+
+ CSS = """
+ @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;500;600;700&family=IBM+Plex+Mono:wght@400;500&family=Source+Serif+4:wght@500;600;700&display=swap');
+
+ :root {
+   --paper: #f4efe6;
+   --ink: #122033;
+   --ink-soft: #4f6073;
+   --accent: #0f5fd7;
+   --accent-soft: #d9e8ff;
+   --line: #c9d1db;
+   --warning: #8a4b0f;
+   --good: #0e6b48;
+ }
+
+ .gradio-container {
+   font-family: "IBM Plex Sans", sans-serif;
+   background:
+     radial-gradient(circle at top right, rgba(15,95,215,0.08), transparent 24rem),
+     linear-gradient(180deg, #faf7f0 0%, var(--paper) 100%);
+   color: var(--ink);
+ }
+
+ #hero {
+   border: 1px solid var(--line);
+   background: linear-gradient(135deg, rgba(255,255,255,0.9), rgba(239,245,255,0.92));
+   border-radius: 24px;
+   padding: 1.25rem 1.4rem;
+   box-shadow: 0 20px 40px rgba(18,32,51,0.08);
+ }
+
+ .eyebrow {
+   font-family: "IBM Plex Mono", monospace;
+   font-size: 0.78rem;
+   letter-spacing: 0.08em;
+   text-transform: uppercase;
+   color: var(--accent);
+ }
+
+ .hero-title {
+   font-family: "Source Serif 4", serif;
+   font-size: 2.2rem;
+   line-height: 1.05;
+   margin: 0.2rem 0 0.5rem 0;
+ }
+
+ .hero-copy {
+   color: var(--ink-soft);
+   max-width: 60rem;
+   font-size: 1rem;
+ }
+
+ .panel-note {
+   border-left: 4px solid var(--accent);
+   background: rgba(15,95,215,0.06);
+   padding: 0.9rem 1rem;
+   border-radius: 12px;
+ }
+
+ .metric-strip {
+   display: grid;
+   grid-template-columns: repeat(3, minmax(0, 1fr));
+   gap: 0.8rem;
+ }
+
+ .metric-card {
+   border: 1px solid var(--line);
+   background: rgba(255,255,255,0.75);
+   padding: 0.8rem 0.9rem;
+   border-radius: 16px;
+ }
+
+ .metric-card strong {
+   display: block;
+   font-size: 1.15rem;
+   margin-top: 0.15rem;
+ }
+ """
+
+ EXAMPLES = {
+     "BTK binding": {
+         "title": "BTK kinase inhibitor binding assay",
+         "description": "In vitro kinase-domain binding assay for Bruton's tyrosine kinase inhibitor ranking.",
+         "organism": "Homo sapiens",
+         "readout": "binding",
+         "assay_format": "biochemical",
+         "assay_type": "binding",
+         "target_uniprot": "Q06187",
+         "smiles": "\n".join(
+             [
+                 "CC1=NC(=O)N(C)C(=O)N1",
+                 "CCOc1ccc2nc(N3CCN(C)CC3)n(C)c(=O)c2c1",
+                 "CC(=O)Nc1ncc(C#N)c(Nc2ccc(F)c(Cl)c2)n1",
+                 "c1ccccc1",
+                 "CCO",
+             ]
+         ),
+     },
+     "ALDH1A1 fluorescence": {
+         "title": "ALDH1A1 inhibition assay",
+         "description": "Cell-based fluorescence assay measuring ALDH1A1 inhibition in human cells.",
+         "organism": "Homo sapiens",
+         "readout": "fluorescence",
+         "assay_format": "cell-based",
+         "assay_type": "inhibition",
+         "target_uniprot": "P00352",
+         "smiles": "\n".join(
+             [
+                 "CC1=CC(=O)N(C)C(=O)N1",
+                 "COC1=CC=C(C=C1)C(=O)O",
+                 "CCN(CC)CCOC1=CC=CC=C1",
+                 "CCOC1=CC=CC=C1",
+                 "CCO",
+             ]
+         ),
+     },
+ }
+
+
+ def _parse_smiles_text(value: str | None) -> list[str]:
+     if not value:
+         return []
+     lines = [line.strip() for line in value.replace(",", "\n").splitlines()]
+     return [line for line in lines if line]
+
+
+ def _read_uploaded_smiles(file_obj: Any) -> list[str]:
+     if file_obj is None:
+         return []
+     path = Path(file_obj.name if hasattr(file_obj, "name") else str(file_obj))
+     suffix = path.suffix.lower()
+     if suffix in {".txt", ".smi", ".smiles"}:
+         return [line.strip() for line in path.read_text().splitlines() if line.strip()]
+     if suffix == ".csv":
+         frame = pd.read_csv(path)
+         for column in ("smiles", "canonical_smiles", "SMILES"):
+             if column in frame.columns:
+                 return [str(item).strip() for item in frame[column].tolist() if str(item).strip()]
+         first = frame.columns[0]
+         return [str(item).strip() for item in frame[first].tolist() if str(item).strip()]
+     raise gr.Error("Upload a .csv, .txt, .smi, or .smiles file.")
+
+
+ def _collect_smiles(smiles_text: str, upload_file: Any) -> tuple[list[str], str | None]:
+     items = _parse_smiles_text(smiles_text) + _read_uploaded_smiles(upload_file)
+     deduped: list[str] = []
+     seen: set[str] = set()
+     for item in items:
+         if item not in seen:
+             deduped.append(item)
+             seen.add(item)
+     warning = None
+     if len(deduped) > MAX_INPUT_SMILES:
+         warning = f"Input truncated to the first {MAX_INPUT_SMILES} unique SMILES for interactive use."
+         deduped = deduped[:MAX_INPUT_SMILES]
+     return deduped, warning
+
+
+ def _load_model():
+     return load_compatibility_model_from_hub(MODEL_REPO_ID)
+
+
+ def _build_summary(query_text: str, valid_rows: list[dict[str, Any]], invalid_rows: list[dict[str, Any]], warning: str | None) -> str:
+     best = valid_rows[0] if valid_rows else None
+     chunks = [
+         "### Run Summary",
+         f"- Model repo: `{MODEL_REPO_ID}`",
+         f"- Assay prompt length: `{len(query_text.split())}` whitespace-delimited words",
+         f"- Valid molecules ranked: `{len(valid_rows)}`",
+         f"- Invalid molecules rejected: `{len(invalid_rows)}`",
+     ]
+     if best is not None:
+         chunks.append(f"- Top hit: `{best['canonical_smiles']}` with score `{best['score']:.3f}`")
+     if warning:
+         chunks.append(f"- Warning: {warning}")
+     chunks.append("")
+     chunks.append("Higher scores mean the model ranks the molecule as more compatible with this assay than lower-scored candidates in the same list. Scores are ranking signals, not calibrated probabilities.")
+     return "\n".join(chunks)
+
+
+ def _results_to_csv(valid_rows: list[dict[str, Any]], invalid_rows: list[dict[str, Any]]) -> str | None:
+     rows = valid_rows + invalid_rows
+     if not rows:
+         return None
+     handle = tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="")
+     writer = csv.DictWriter(handle, fieldnames=["rank", "input_smiles", "canonical_smiles", "smiles_hash", "score", "valid", "error"])
+     writer.writeheader()
+     rank = 1
+     for row in valid_rows:
+         writer.writerow(
+             {
+                 "rank": rank,
+                 "input_smiles": row["input_smiles"],
+                 "canonical_smiles": row["canonical_smiles"],
+                 "smiles_hash": row["smiles_hash"],
+                 "score": row["score"],
+                 "valid": True,
+                 "error": "",
+             }
+         )
+         rank += 1
+     for row in invalid_rows:
+         writer.writerow(
+             {
+                 "rank": "",
+                 "input_smiles": row["input_smiles"],
+                 "canonical_smiles": "",
+                 "smiles_hash": "",
+                 "score": "",
+                 "valid": False,
+                 "error": row.get("error", "invalid_smiles"),
+             }
+         )
+     handle.close()
+     return handle.name
+
+
+ def run_ranking(
+     title: str,
+     description: str,
+     organism: str,
+     readout: str,
+     assay_format: str,
+     assay_type: str,
+     target_uniprot: str,
+     smiles_text: str,
+     upload_file: Any,
+     top_k: int,
+ ):
+     smiles_values, warning = _collect_smiles(smiles_text, upload_file)
+     if not smiles_values:
+         raise gr.Error("Provide at least one SMILES entry by paste or file upload.")
+     query = AssayQuery(
+         title=title or "",
+         description=description or "",
+         organism=organism or "",
+         readout=readout or "",
+         assay_format=assay_format or "",
+         assay_type=assay_type or "",
+         target_uniprot=[token.strip() for token in target_uniprot.split(",") if token.strip()],
+     )
+     assay_text = serialize_assay_query(query)
+     model = _load_model()
+     ranked = rank_compounds(model, assay_text=assay_text, smiles_list=smiles_values, top_k=top_k or None)
+     valid_rows = [row for row in ranked if row["valid"]]
+     invalid_rows = [row for row in ranked if not row["valid"]]
+
+     display_rows = [
+         {
+             "rank": idx + 1,
+             "input_smiles": row["input_smiles"],
+             "canonical_smiles": row["canonical_smiles"],
+             "smiles_hash": row["smiles_hash"],
+             "score": round(float(row["score"]), 4),
+         }
+         for idx, row in enumerate(valid_rows)
+     ]
+     invalid_display = [
+         {"input_smiles": row["input_smiles"], "error": row.get("error", "invalid_smiles")}
+         for row in invalid_rows
+     ]
+     summary = _build_summary(assay_text, valid_rows, invalid_rows, warning)
+     csv_path = _results_to_csv(valid_rows, invalid_rows)
+     return summary, assay_text, pd.DataFrame(display_rows), pd.DataFrame(invalid_display), csv_path
+
+
+ def load_example(example_name: str):
+     example = EXAMPLES[example_name]
+     return (
+         example["title"],
+         example["description"],
+         example["organism"],
+         example["readout"],
+         example["assay_format"],
+         example["assay_type"],
+         example["target_uniprot"],
+         example["smiles"],
+     )
+
+
+ with gr.Blocks(css=CSS, title="BioAssayAlign Compatibility Explorer") as demo:
+     gr.Markdown(
+         """
+         <div id="hero">
+           <div class="eyebrow">BioAssayAlign · scientist-facing ranking demo</div>
+           <div class="hero-title">Rank candidate molecules for a bioassay</div>
+           <div class="hero-copy">
+             Build an assay query from structured fields, paste or upload a candidate molecule list, and get a ranked output from the current BioAssayAlign compatibility model.
+             This app is designed for triage and prioritization, not for direct potency claims.
+           </div>
+         </div>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             gr.Markdown(
+                 """
+                 <div class="panel-note">
+                   Use the structured fields if you have them. Missing fields are allowed, but species, readout, and target metadata usually help.
+                 </div>
+                 """
+             )
+         with gr.Column(scale=4):
+             gr.Markdown(
+                 f"""
+                 <div class="metric-strip">
+                   <div class="metric-card"><span>Default model</span><strong>{MODEL_REPO_ID}</strong></div>
+                   <div class="metric-card"><span>Expected use</span><strong>ranking, not probability</strong></div>
+                   <div class="metric-card"><span>Interactive cap</span><strong>{MAX_INPUT_SMILES} SMILES</strong></div>
+                 </div>
+                 """
+             )
+
+     with gr.Tab("Rank Compounds"):
+         with gr.Row():
+             with gr.Column(scale=6):
+                 example_name = gr.Dropdown(choices=list(EXAMPLES.keys()), value="BTK binding", label="Load an example")
+                 load_example_btn = gr.Button("Load Example", variant="secondary")
+                 assay_title = gr.Textbox(label="Assay title")
+                 description = gr.Textbox(label="Description", lines=6, placeholder="Describe the assay in practical lab language.")
+                 with gr.Row():
+                     organism = gr.Textbox(label="Organism", placeholder="Homo sapiens")
+                     readout = gr.Textbox(label="Readout", placeholder="binding / fluorescence / luminescence")
+                 with gr.Row():
+                     assay_format = gr.Textbox(label="Assay format", placeholder="biochemical / cell-based")
+                     assay_type = gr.Textbox(label="Assay type", placeholder="binding / inhibition / activation")
+                 target_uniprot = gr.Textbox(label="Target UniProt IDs", placeholder="Q06187, P52333")
+
+             with gr.Column(scale=5):
+                 smiles_text = gr.Textbox(
+                     label="Candidate SMILES",
+                     lines=14,
+                     placeholder="Paste one SMILES per line. CSV upload is optional and will be merged.",
+                 )
+                 upload_file = gr.File(label="Upload CSV / TXT / SMI", file_count="single", file_types=[".csv", ".txt", ".smi", ".smiles"])
+                 top_k = gr.Slider(label="Top-K rows to display", minimum=5, maximum=200, step=5, value=DEFAULT_TOP_K)
+                 run_btn = gr.Button("Rank Molecules", variant="primary")
+                 clear_btn = gr.ClearButton(value="Clear", components=[assay_title, description, organism, readout, assay_format, assay_type, target_uniprot, smiles_text, upload_file])
+
+         summary = gr.Markdown()
+         with gr.Accordion("Serialized assay text used by the model", open=False):
+             assay_preview = gr.Textbox(lines=12, show_copy_button=True, label="Model-facing assay text")
+         ranked_df = gr.Dataframe(label="Ranked molecules", interactive=False, wrap=True)
+         invalid_df = gr.Dataframe(label="Rejected inputs", interactive=False, wrap=True)
+         download_file = gr.File(label="Download CSV")
+
+         load_example_btn.click(
+             load_example,
+             inputs=[example_name],
+             outputs=[assay_title, description, organism, readout, assay_format, assay_type, target_uniprot, smiles_text],
+         )
+         run_btn.click(
+             run_ranking,
+             inputs=[assay_title, description, organism, readout, assay_format, assay_type, target_uniprot, smiles_text, upload_file, top_k],
+             outputs=[summary, assay_preview, ranked_df, invalid_df, download_file],
+         )
+
+     with gr.Tab("How To Use This"):
+         gr.Markdown(
+             """
+             ### Recommended workflow
+
+             1. Describe the assay in plain scientific language.
+             2. Add metadata if you know it: organism, readout, format, assay type, target UniProt.
+             3. Paste a candidate list or upload a CSV with a `smiles` column.
+             4. Rank the list and inspect the top molecules first.
+
+             ### What the score means
+
+             - The score is a ranking signal.
+             - Higher means “more compatible than the other molecules in this submitted list”.
+             - It is **not** a calibrated activity probability and it is **not** an IC50 prediction.
+
+             ### Good input habits
+
+             - Prefer parent, neutralized, chemically sensible SMILES.
+             - Keep assay descriptions concrete.
+             - If the assay is target-defined, add the UniProt ID.
+
+             ### What this Space is not
+
+             - not a generative chemistry tool
+             - not a medicinal chemistry oracle
+             - not a wet-lab substitute
+             """
+         )
+
+
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=4).launch()
examples/btk_candidates.csv ADDED
@@ -0,0 +1,6 @@
+ smiles
+ CC1=NC(=O)N(C)C(=O)N1
+ CCOc1ccc2nc(N3CCN(C)CC3)n(C)c(=O)c2c1
+ CC(=O)Nc1ncc(C#N)c(Nc2ccc(F)c(Cl)c2)n1
+ c1ccccc1
+ CCO
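
A CSV like the example above is consumed by `_read_uploaded_smiles` in app.py: it looks for a `smiles` (or `canonical_smiles`) column and falls back to the first column. A stdlib-only sketch of that column-selection logic, separate from the Space's pandas-based implementation:

```python
import csv
import io

def read_candidate_smiles(csv_text: str) -> list[str]:
    # Parse the CSV, prefer a recognized SMILES column, else use column 0.
    rows = list(csv.reader(io.StringIO(csv_text)))
    if not rows:
        return []
    header, body = rows[0], rows[1:]
    for name in ("smiles", "canonical_smiles", "SMILES"):
        if name in header:
            idx = header.index(name)
            break
    else:
        idx = 0  # fall back to the first column, as app.py does
    return [row[idx].strip() for row in body if row and row[idx].strip()]

smiles = read_candidate_smiles("smiles\nCCO\nc1ccccc1\n")  # → ["CCO", "c1ccccc1"]
```

Pasted SMILES and uploaded rows are merged and de-duplicated before ranking, so the same molecule appearing in both inputs is scored once.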
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio>=5.0,<6
+ huggingface_hub>=0.30
+ numpy<2
+ pandas>=2.2
+ rdkit-pypi>=2022.9.5
+ sentence-transformers>=5.2
+ torch>=2.2
+ transformers>=4.51
space_runtime.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ import re
6
+ from dataclasses import dataclass
7
+ from functools import lru_cache
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from huggingface_hub import snapshot_download
15
+ from rdkit import Chem, DataStructs, RDLogger
16
+ from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors
17
+ from rdkit.Chem.MolStandardize import rdMolStandardize
18
+ from sentence_transformers import SentenceTransformer
19
+ from torch import nn
20
+ from transformers import AutoModel, AutoTokenizer
21
+
22
+ RDLogger.DisableLog("rdApp.*")
23
+
24
+ DEFAULT_ASSAY_TASK = (
25
+ "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
26
+ )
27
+ DEFAULT_DESCRIPTOR_NAMES = (
28
+ "mol_wt",
29
+ "logp",
30
+ "tpsa",
31
+ "heavy_atoms",
32
+ "hbd",
33
+ "hba",
34
+ "rot_bonds",
35
+ "ring_count",
36
+ "aromatic_rings",
37
+ "aliphatic_rings",
38
+ "saturated_rings",
39
+ "fraction_csp3",
40
+ "heteroatoms",
41
+ "amide_bonds",
42
+ "fragments",
43
+ "formal_charge",
44
+ "max_atomic_num",
45
+ "metal_atom_count",
46
+ "halogen_count",
47
+ "nitrogen_count",
48
+ "oxygen_count",
49
+ "sulfur_count",
50
+ "phosphorus_count",
51
+ "fluorine_count",
52
+ "chlorine_count",
53
+ "bromine_count",
54
+ "iodine_count",
55
+ "aromatic_atom_count",
56
+ "spiro_atoms",
57
+ "bridgehead_atoms",
58
+ )
59
+ ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
60
+ SECTION_ORDER = [
61
+ "ASSAY_TITLE",
62
+ "DESCRIPTION",
63
+ "ORGANISM",
64
+ "READOUT",
65
+ "ASSAY_FORMAT",
66
+ "ASSAY_TYPE",
67
+ "TARGET_UNIPROT",
68
+ ]
69
+ ASSAY_SECTION_RE = re.compile(r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n")
70
+ ORGANISM_ALIASES = {
71
+ "9606": "homo_sapiens",
72
+ "10090": "mus_musculus",
73
+ "10116": "rattus_norvegicus",
74
+ "4932": "saccharomyces_cerevisiae",
75
+ }
76
+
77
+
78
+ @dataclass
79
+ class AssayQuery:
80
+ title: str = ""
81
+ description: str = ""
82
+ organism: str = ""
83
+ readout: str = ""
84
+ assay_format: str = ""
85
+ assay_type: str = ""
86
+ target_uniprot: list[str] | None = None
87
+
88
+
89
+ def smiles_sha256(smiles: str) -> str:
90
+ return hashlib.sha256(smiles.encode("utf-8")).hexdigest()
91
+
92
+
93
+ @lru_cache(maxsize=1_000_000)
94
+ def _standardize_smiles_v2_cached(smiles: str) -> str | None:
95
+ mol = Chem.MolFromSmiles(smiles)
96
+ if mol is None:
97
+ return None
98
+ try:
99
+ mol = rdMolStandardize.Cleanup(mol)
100
+ mol = rdMolStandardize.FragmentParent(mol)
101
+ mol = rdMolStandardize.Uncharger().uncharge(mol)
102
+ mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
103
+ Chem.SanitizeMol(mol)
104
+ except Exception:
105
+ return None
106
+ if mol.GetNumHeavyAtoms() < 2:
107
+ return None
108
+ standardized = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
109
+ if not standardized or "." in standardized:
110
+ return None
111
+ return standardized
112
+
113
+
114
+ def standardize_smiles_v2(smiles: str | None) -> str | None:
115
+ if not smiles:
116
+ return None
117
+ token = smiles.strip()
118
+ if not token:
119
+ return None
120
+ return _standardize_smiles_v2_cached(token)
121
+
122
+
123
+ def serialize_assay_query(query: AssayQuery) -> str:
124
+ targets = ", ".join(query.target_uniprot or [])
125
+ values = {
126
+ "ASSAY_TITLE": query.title.strip(),
127
+ "DESCRIPTION": query.description.strip(),
128
+ "ORGANISM": query.organism.strip(),
129
+ "READOUT": query.readout.strip(),
130
+ "ASSAY_FORMAT": query.assay_format.strip(),
131
+ "ASSAY_TYPE": query.assay_type.strip(),
132
+ "TARGET_UNIPROT": targets.strip(),
133
+ }
134
+ return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)
135
+
136
+
137
+ def _parse_assay_sections(assay_text: str) -> dict[str, str]:
138
+ sections = {key: "" for key in SECTION_ORDER}
139
+ parts = ASSAY_SECTION_RE.split(assay_text)
140
+ for idx in range(1, len(parts), 2):
141
+ key = parts[idx]
142
+ value = parts[idx + 1] if idx + 1 < len(parts) else ""
143
+ if key in sections:
144
+ sections[key] = value.strip()
145
+ return sections
146
+
147
+
148
+ def _hash_bucket(value: str, dim: int) -> int:
149
+ return abs(hash(value)) % max(dim, 1)
150
+
151
+
152
+ def _normalize_metadata_token(value: str) -> str:
153
+ return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_")
154
+
155
+
156
+ def _normalize_organism_token(value: str) -> str:
157
+ raw = value.strip()
158
+ if not raw:
159
+ return ""
160
+ aliased = ORGANISM_ALIASES.get(raw, raw)
161
+ return _normalize_metadata_token(aliased)
162
+
163
+
164
+ def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
165
+ if dim <= 0:
166
+ return np.zeros((0,), dtype=np.float32)
167
+ sections = _parse_assay_sections(assay_text)
168
+ tokens: list[str] = []
169
+ organism = _normalize_organism_token(sections.get("ORGANISM", ""))
170
+ if organism:
171
+ tokens.append(f"organism:{organism}")
172
+ for key in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE"):
173
+ value = _normalize_metadata_token(sections.get(key, ""))
174
+ if value:
175
+ tokens.append(f"{key.lower()}:{value}")
176
+ for target in sections.get("TARGET_UNIPROT", "").split(","):
177
+ token = target.strip().upper()
178
+ if token:
179
+ tokens.append(f"target:{token}")
180
+ vec = np.zeros((dim,), dtype=np.float32)
181
+ for token in tokens:
182
+ vec[_hash_bucket(token, dim)] += 1.0
183
+ norm = float(np.linalg.norm(vec))
184
+ if norm > 0:
185
+ vec /= norm
186
+ return vec
187
+
188
+
189
+ def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
190
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits, useChirality=use_chirality)
191
+ arr = np.zeros((n_bits,), dtype=np.uint8)
192
+ DataStructs.ConvertToNumpyArray(fp, arr)
193
+ return arr
194
+
195
+
196
+ def _maccs_bits_from_mol(mol) -> np.ndarray:
197
+ fp = MACCSkeys.GenMACCSKeys(mol)
198
+ arr = np.zeros((fp.GetNumBits(),), dtype=np.uint8)
199
+ DataStructs.ConvertToNumpyArray(fp, arr)
200
+ return arr
201
+
202
+
203
+ def _count_atomic_nums(mol) -> dict[int, int]:
204
+ counts: dict[int, int] = {}
205
+ for atom in mol.GetAtoms():
206
+ atomic_num = int(atom.GetAtomicNum())
207
+ counts[atomic_num] = counts.get(atomic_num, 0) + 1
208
+ return counts
209
+
210
+
211
+ def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
212
+ counts = _count_atomic_nums(mol)
213
+ fragments = Chem.GetMolFrags(mol)
214
+ formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
215
+ max_atomic_num = max(counts) if counts else 0
216
+ metal_atom_count = sum(count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS)
217
+ halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
218
+ aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
219
+ values = {
220
+ "mol_wt": float(Descriptors.MolWt(mol)),
221
+ "logp": float(Crippen.MolLogP(mol)),
222
+ "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
223
+ "heavy_atoms": float(mol.GetNumHeavyAtoms()),
224
+ "hbd": float(Lipinski.NumHDonors(mol)),
225
+ "hba": float(Lipinski.NumHAcceptors(mol)),
226
+ "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
227
+ "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
228
+ "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
229
+ "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
230
+ "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
231
+ "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
232
+ "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
233
+ "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
234
+ "fragments": float(len(fragments)),
235
+ "formal_charge": float(formal_charge),
236
+ "max_atomic_num": float(max_atomic_num),
237
+ "metal_atom_count": float(metal_atom_count),
238
+ "halogen_count": float(halogen_count),
239
+ "nitrogen_count": float(counts.get(7, 0)),
240
+ "oxygen_count": float(counts.get(8, 0)),
241
+ "sulfur_count": float(counts.get(16, 0)),
242
+ "phosphorus_count": float(counts.get(15, 0)),
243
+ "fluorine_count": float(counts.get(9, 0)),
244
+ "chlorine_count": float(counts.get(17, 0)),
245
+ "bromine_count": float(counts.get(35, 0)),
246
+ "iodine_count": float(counts.get(53, 0)),
247
+ "aromatic_atom_count": float(aromatic_atom_count),
248
+ "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
249
+ "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
250
+ }
251
+ return np.array([values[name] for name in names], dtype=np.float32)
252
+
253
+
254
+ class CompatibilityHead(nn.Module):
255
+ def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
256
+ super().__init__()
257
+ self.assay_norm = nn.LayerNorm(assay_dim)
258
+ self.assay_proj = nn.Linear(assay_dim, projection_dim)
259
+ self.mol_norm = nn.LayerNorm(molecule_dim)
260
+ self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
261
+ self.score_mlp = nn.Sequential(
262
+ nn.Linear(projection_dim * 4, hidden_dim),
263
+ nn.GELU(),
264
+ nn.Dropout(dropout),
265
+ nn.Linear(hidden_dim, 1),
266
+ )
267
+ self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
268
+
269
+ def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
270
+ vec = self.assay_proj(self.assay_norm(assay_features))
271
+ return F.normalize(vec, p=2, dim=-1)
272
+
273
+ def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
274
+ vec = self.mol_proj(self.mol_norm(molecule_features))
275
+ return F.normalize(vec, p=2, dim=-1)
276
+
277
+ def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
278
+ assay_vec = self.encode_assay(assay_features)
279
+ mol_vec = self.encode_molecule(candidate_features)
280
+ assay_expand = assay_vec.unsqueeze(1).expand(-1, mol_vec.shape[1], -1)
281
+ dot_scores = (assay_expand * mol_vec).sum(dim=-1)
282
+ mlp_input = torch.cat(
283
+ [assay_expand, mol_vec, assay_expand * mol_vec, torch.abs(assay_expand - mol_vec)],
284
+ dim=-1,
285
+ )
286
+ mlp_scores = self.score_mlp(mlp_input).squeeze(-1)
287
+ logits = dot_scores * self.dot_scale + mlp_scores
288
+ return logits, assay_vec, mol_vec
289
+
290
+
+
+class SpaceCompatibilityModel:
+    def __init__(
+        self,
+        *,
+        assay_encoder: SentenceTransformer,
+        compatibility_head: CompatibilityHead,
+        assay_task_description: str,
+        fingerprint_radii: tuple[int, ...],
+        fingerprint_bits: int,
+        use_chirality: bool,
+        use_maccs: bool,
+        use_rdkit_descriptors: bool,
+        descriptor_names: tuple[str, ...],
+        descriptor_mean: np.ndarray | None,
+        descriptor_std: np.ndarray | None,
+        molecule_transformer_model_name: str,
+        molecule_transformer_batch_size: int,
+        molecule_transformer_max_length: int,
+        use_assay_metadata_features: bool,
+        assay_metadata_dim: int,
+    ) -> None:
+        self.assay_encoder = assay_encoder
+        self.compatibility_head = compatibility_head.eval()
+        self.assay_task_description = assay_task_description
+        self.fingerprint_radii = fingerprint_radii
+        self.fingerprint_bits = fingerprint_bits
+        self.use_chirality = use_chirality
+        self.use_maccs = use_maccs
+        self.use_rdkit_descriptors = use_rdkit_descriptors
+        self.descriptor_names = descriptor_names
+        self.descriptor_mean = descriptor_mean
+        self.descriptor_std = descriptor_std
+        self.molecule_transformer_model_name = molecule_transformer_model_name
+        self.molecule_transformer_batch_size = molecule_transformer_batch_size
+        self.molecule_transformer_max_length = molecule_transformer_max_length
+        self.use_assay_metadata_features = use_assay_metadata_features
+        self.assay_metadata_dim = assay_metadata_dim
+        self._molecule_transformer_tokenizer = None
+        self._molecule_transformer_model = None
+        self._molecule_transformer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def _format_assay_query(self, assay_text: str) -> str:
+        return f"Instruct: {self.assay_task_description.strip()}\nQuery: {assay_text.strip()}"
+
+    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
+        assay_features = self.assay_encoder.encode(
+            [self._format_assay_query(assay_text)],
+            batch_size=1,
+            normalize_embeddings=True,
+            show_progress_bar=False,
+            convert_to_numpy=True,
+        )[0].astype(np.float32)
+        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
+            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
+            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
+        return assay_features
+
+    def _ensure_molecule_transformer_loaded(self) -> None:
+        if not self.molecule_transformer_model_name or self._molecule_transformer_model is not None:
+            return
+        dtype = torch.float16 if self._molecule_transformer_device.type == "cuda" else torch.float32
+        self._molecule_transformer_tokenizer = AutoTokenizer.from_pretrained(
+            self.molecule_transformer_model_name,
+            trust_remote_code=True,
+        )
+        self._molecule_transformer_model = AutoModel.from_pretrained(
+            self.molecule_transformer_model_name,
+            trust_remote_code=True,
+            torch_dtype=dtype,
+        ).to(self._molecule_transformer_device)
+        self._molecule_transformer_model.eval()
+
+    def _encode_molecule_transformer_batch(self, smiles_values: list[str]) -> np.ndarray | None:
+        if not self.molecule_transformer_model_name:
+            return None
+        self._ensure_molecule_transformer_loaded()
+        assert self._molecule_transformer_model is not None
+        assert self._molecule_transformer_tokenizer is not None
+        outputs: list[np.ndarray] = []
+        batch_size = max(self.molecule_transformer_batch_size, 1)
+        with torch.no_grad():
+            for start in range(0, len(smiles_values), batch_size):
+                batch = smiles_values[start : start + batch_size]
+                encoded = self._molecule_transformer_tokenizer(
+                    batch,
+                    padding=True,
+                    truncation=True,
+                    max_length=self.molecule_transformer_max_length,
+                    return_tensors="pt",
+                )
+                encoded = {key: value.to(self._molecule_transformer_device) for key, value in encoded.items()}
+                hidden = self._molecule_transformer_model(**encoded).last_hidden_state
+                mask = encoded["attention_mask"].unsqueeze(-1)
+                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
+                outputs.append(pooled.detach().cpu().to(torch.float32).numpy())
+        return np.concatenate(outputs, axis=0).astype(np.float32)
+
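The pooling step above averages token embeddings over non-padding positions only, which is why the attention mask is multiplied in before the sum. A standalone NumPy sketch of the same masked mean, using toy shapes rather than real model outputs:

```python
import numpy as np

def masked_mean_pool(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # hidden: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    mask = attention_mask[..., None].astype(hidden.dtype)
    summed = (hidden * mask).sum(axis=1)           # padding rows contribute zero
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # analogue of .clamp(min=1)
    return summed / counts

hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # third token is padding
mask = np.array([[1, 1, 0]])
pooled = masked_mean_pool(hidden, mask)
assert np.allclose(pooled, [[2.0, 3.0]])  # mean of the two real tokens only
```

Without the mask, the padding row `[9.0, 9.0]` would drag the pooled vector toward values that depend on batch padding length rather than the molecule.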
+    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
+        transformer_matrix = self._encode_molecule_transformer_batch(smiles_values)
+        rows: list[np.ndarray] = []
+        for idx, smiles in enumerate(smiles_values):
+            normalized = standardize_smiles_v2(smiles) or smiles
+            mol = Chem.MolFromSmiles(normalized)
+            if mol is None:
+                raise ValueError(f"Could not parse SMILES: {normalized}")
+            bit_blocks: list[np.ndarray] = [
+                _morgan_bits_from_mol(mol, radius=int(radius), n_bits=self.fingerprint_bits, use_chirality=self.use_chirality)
+                for radius in self.fingerprint_radii
+            ]
+            if self.use_maccs:
+                bit_blocks.append(_maccs_bits_from_mol(mol))
+            output_blocks: list[np.ndarray] = [np.concatenate(bit_blocks, axis=0).astype(np.float32)]
+            if self.use_rdkit_descriptors and self.descriptor_names:
+                dense = _molecule_descriptor_vector(mol, names=self.descriptor_names)
+                if self.descriptor_mean is not None and self.descriptor_std is not None:
+                    dense = (dense - self.descriptor_mean) / self.descriptor_std
+                output_blocks.append(dense.astype(np.float32))
+            if transformer_matrix is not None:
+                output_blocks.append(np.asarray(transformer_matrix[idx], dtype=np.float32))
+            rows.append(np.concatenate(output_blocks, axis=0).astype(np.float32))
+        return np.stack(rows, axis=0)
+
+
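Each molecule row is just a concatenation of fixed-width blocks: fingerprint bits, optionally z-scored RDKit descriptors, and optionally a transformer embedding. A sketch of that layout with hypothetical block widths (the real widths come from the checkpoint's feature spec):

```python
import numpy as np

# Hypothetical widths: 8 fingerprint bits, 2 descriptors, 4 embedding dims.
bits = np.zeros(8, dtype=np.float32)
bits[[1, 5]] = 1.0                              # "on" Morgan bits
dense = np.array([3.2, 0.5], dtype=np.float32)  # raw RDKit descriptor values
mean = np.array([3.0, 0.4], dtype=np.float32)   # training-set statistics
std = np.array([1.0, 0.2], dtype=np.float32)
dense = (dense - mean) / std                    # z-score with stored mean/std
embedding = np.full(4, 0.1, dtype=np.float32)   # transformer row for this SMILES
row = np.concatenate([bits, dense, embedding])
assert row.shape == (14,)                       # block widths simply add up
```

Because the head's `mol_proj` input width is fixed at training time, the same blocks must be present, in the same order and widths, at inference.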
+def _load_sentence_transformer(model_name: str) -> SentenceTransformer:
+    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    encoder = SentenceTransformer(
+        model_name,
+        trust_remote_code=True,
+        model_kwargs={"torch_dtype": dtype},
+    )
+    if getattr(encoder, "tokenizer", None) is not None:
+        encoder.tokenizer.padding_side = "left"
+    return encoder
+
+
+def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
+    spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
+    if spec:
+        return spec
+    radii = tuple(int(item) for item in (cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)]))
+    return {
+        "fingerprint_radii": list(radii),
+        "fingerprint_bits": int(cfg["fingerprint_bits"]),
+        "use_chirality": bool(cfg.get("use_chirality", False)),
+        "use_maccs": bool(cfg.get("use_maccs", False)),
+        "use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
+        "descriptor_names": [],
+        "descriptor_mean": None,
+        "descriptor_std": None,
+        "molecule_transformer_model_name": str(cfg.get("molecule_transformer_model_name") or ""),
+        "molecule_transformer_max_length": int(cfg.get("molecule_transformer_max_length", 128) or 128),
+    }
+
+
+def load_compatibility_model(model_dir: str | Path) -> SpaceCompatibilityModel:
+    model_path = Path(model_dir)
+    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
+    metadata = json.loads((model_path / "training_metadata.json").read_text())
+    cfg = metadata["config"]
+    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)
+
+    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])
+    assay_dim = int(checkpoint["model_state_dict"]["assay_proj.weight"].shape[1])
+    molecule_dim = int(checkpoint["model_state_dict"]["mol_proj.weight"].shape[1])
+    head = CompatibilityHead(
+        assay_dim=assay_dim,
+        molecule_dim=molecule_dim,
+        projection_dim=int(cfg["projection_dim"]),
+        hidden_dim=int(cfg["hidden_dim"]),
+        dropout=float(cfg["dropout"]),
+    )
+    load_result = head.load_state_dict(checkpoint["model_state_dict"], strict=False)
+    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
+    unexpected = set(load_result.unexpected_keys)
+    missing = set(load_result.missing_keys)
+    if unexpected or (missing - allowed_missing):
+        raise RuntimeError(
+            f"Checkpoint mismatch: unexpected={sorted(unexpected)} missing={sorted(missing)}"
+        )
+    return SpaceCompatibilityModel(
+        assay_encoder=encoder,
+        compatibility_head=head,
+        assay_task_description=checkpoint.get("assay_task_description") or cfg.get("assay_task_description", DEFAULT_ASSAY_TASK),
+        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
+        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
+        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
+        use_maccs=bool(feature_spec.get("use_maccs", cfg.get("use_maccs", False))),
+        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", cfg.get("use_rdkit_descriptors", False))),
+        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
+        descriptor_mean=np.array(feature_spec["descriptor_mean"], dtype=np.float32) if feature_spec.get("descriptor_mean") is not None else None,
+        descriptor_std=np.array(feature_spec["descriptor_std"], dtype=np.float32) if feature_spec.get("descriptor_std") is not None else None,
+        molecule_transformer_model_name=str(feature_spec.get("molecule_transformer_model_name") or cfg.get("molecule_transformer_model_name") or ""),
+        molecule_transformer_batch_size=int(cfg.get("molecule_transformer_batch_size", 128) or 128),
+        molecule_transformer_max_length=int(feature_spec.get("molecule_transformer_max_length") or cfg.get("molecule_transformer_max_length", 128) or 128),
+        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
+        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
+    )
+
+
+@lru_cache(maxsize=1)
+def load_compatibility_model_from_hub(model_repo_id: str) -> SpaceCompatibilityModel:
+    model_dir = snapshot_download(
+        repo_id=model_repo_id,
+        repo_type="model",
+        allow_patterns=["best_model.pt", "training_metadata.json", "README.md"],
+    )
+    return load_compatibility_model(model_dir)
+
+
+def rank_compounds(
+    model: SpaceCompatibilityModel,
+    *,
+    assay_text: str,
+    smiles_list: list[str],
+    top_k: int | None = None,
+) -> list[dict[str, Any]]:
+    if not smiles_list:
+        return []
+    assay_features = model._build_assay_feature_array(assay_text)
+    assay_tensor = torch.from_numpy(assay_features.astype(np.float32)).unsqueeze(0)
+
+    valid_items: list[tuple[str, str]] = []
+    invalid_items: list[dict[str, Any]] = []
+    for raw_smiles in smiles_list:
+        standardized = standardize_smiles_v2(raw_smiles)
+        if standardized is None:
+            invalid_items.append(
+                {
+                    "input_smiles": raw_smiles,
+                    "canonical_smiles": None,
+                    "smiles_hash": None,
+                    "score": None,
+                    "valid": False,
+                    "error": "invalid_smiles",
+                }
+            )
+            continue
+        valid_items.append((raw_smiles, standardized))
+
+    ranked_items: list[dict[str, Any]] = []
+    if valid_items:
+        feature_matrix = model.build_molecule_feature_matrix([item[1] for item in valid_items])
+        candidate_tensor = torch.from_numpy(feature_matrix).unsqueeze(0)
+        with torch.no_grad():
+            logits, _, _ = model.compatibility_head.score_candidates(
+                assay_tensor.to(dtype=torch.float32),
+                candidate_tensor.to(dtype=torch.float32),
+            )
+        scores = logits.squeeze(0).cpu().numpy().tolist()
+        for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True):
+            ranked_items.append(
+                {
+                    "input_smiles": raw_smiles,
+                    "canonical_smiles": canonical,
+                    "smiles_hash": smiles_sha256(canonical),
+                    "score": float(score),
+                    "valid": True,
+                }
+            )
+        ranked_items.sort(key=lambda item: item["score"], reverse=True)
+        if top_k is not None and top_k > 0:
+            ranked_items = ranked_items[:top_k]
+
+    return ranked_items + invalid_items
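Running `rank_compounds` end to end needs the Hub checkpoint and RDKit, but its ordering contract can be illustrated with plain dictionaries: valid candidates are sorted by score (best first), `top_k` truncates only the scored rows, and invalid-SMILES rows are appended at the end. The values below are made up:

```python
# Toy rows standing in for rank_compounds output (scores are fabricated).
scored = [
    {"input_smiles": "CCO", "score": 1.2, "valid": True},
    {"input_smiles": "c1ccccc1", "score": 3.4, "valid": True},
]
invalid = [{"input_smiles": "not-a-smiles", "score": None, "valid": False}]

scored.sort(key=lambda item: item["score"], reverse=True)  # best first
top_k = 10
results = scored[:top_k] + invalid  # truncation applies to scored rows only

assert results[0]["score"] == 3.4
assert results[-1]["valid"] is False  # invalid rows always trail the ranking
```

Callers should therefore check the `valid` flag (or a non-`None` score) before treating a trailing row as a ranked candidate.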