lighteternal committed on
Commit
8a63cbb
·
verified ·
1 Parent(s): 0f8d9b1

Upload bioassayalign_compatibility.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. bioassayalign_compatibility.py +283 -148
bioassayalign_compatibility.py CHANGED
@@ -39,7 +39,9 @@ SECTION_ORDER = [
39
  "ASSAY_TYPE",
40
  "TARGET_UNIPROT",
41
  ]
42
- ASSAY_SECTION_RE = re.compile(r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n")
 
 
43
  ORGANISM_ALIASES = {
44
  "9606": "homo_sapiens",
45
  "10090": "mus_musculus",
@@ -106,6 +108,10 @@ def serialize_assay_query(query: AssayQuery) -> str:
106
  return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)
107
 
108
 
 
 
 
 
109
  def _parse_assay_sections(assay_text: str) -> dict[str, str]:
110
  sections = {key: "" for key in SECTION_ORDER}
111
  parts = ASSAY_SECTION_RE.split(assay_text)
@@ -219,7 +225,9 @@ def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIP
219
  fragments = Chem.GetMolFrags(mol)
220
  formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
221
  max_atomic_num = max(counts) if counts else 0
222
- metal_atom_count = sum(count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS)
 
 
223
  halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
224
  aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
225
  values = {
@@ -258,177 +266,304 @@ def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIP
258
 
259
 
260
  class CompatibilityHead(nn.Module):
261
- def __init__(
262
- self,
263
- *,
264
- assay_dim: int,
265
- molecule_dim: int,
266
- projection_dim: int,
267
- hidden_dim: int,
268
- dropout: float,
269
- metadata_dim: int = 0,
270
- ) -> None:
271
  super().__init__()
272
- self.metadata_dim = metadata_dim
273
- assay_input_dim = assay_dim + metadata_dim
274
- self.assay_proj = nn.Sequential(
275
- nn.Linear(assay_input_dim, projection_dim),
276
- nn.GELU(),
277
- nn.Dropout(dropout),
278
- )
279
- self.molecule_proj = nn.Sequential(
280
- nn.Linear(molecule_dim, projection_dim),
281
- nn.GELU(),
282
- nn.Dropout(dropout),
283
- )
284
- self.scorer = nn.Sequential(
285
  nn.Linear(projection_dim * 4, hidden_dim),
286
  nn.GELU(),
287
  nn.Dropout(dropout),
288
  nn.Linear(hidden_dim, 1),
289
  )
290
-
291
- def forward(self, assay_vec: torch.Tensor, molecule_vec: torch.Tensor, assay_metadata: torch.Tensor | None = None) -> torch.Tensor:
292
- if assay_metadata is not None and assay_metadata.numel():
293
- assay_input = torch.cat([assay_vec, assay_metadata], dim=-1)
294
- else:
295
- assay_input = assay_vec
296
- assay_hidden = self.assay_proj(assay_input)
297
- molecule_hidden = self.molecule_proj(molecule_vec)
298
- interaction = torch.cat(
 
 
 
 
 
 
 
299
  [
300
- assay_hidden,
301
- molecule_hidden,
302
- assay_hidden * molecule_hidden,
303
- torch.abs(assay_hidden - molecule_hidden),
304
  ],
305
  dim=-1,
306
  )
307
- return self.scorer(interaction).squeeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
 
310
- class CompatibilityModel:
311
- def __init__(self, assay_encoder: SentenceTransformer, metadata: dict[str, Any], model_state_dict: dict[str, Any], *, device: str | None = None) -> None:
312
- self.metadata = metadata
313
- self.config = metadata["config"]
314
- self.feature_spec = metadata["molecule_feature_spec"]
315
- self.metadata_dim = int(self.config.get("assay_metadata_dim", 0))
316
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
317
  self.assay_encoder = assay_encoder
318
- self.assay_encoder.max_seq_length = 512
319
- self.assay_dim = int(self.assay_encoder.get_sentence_embedding_dimension())
320
- self.molecule_dim = int(metadata["feature_counts"]["molecule_dim"])
321
- self.head = CompatibilityHead(
322
- assay_dim=self.assay_dim,
323
- molecule_dim=self.molecule_dim,
324
- projection_dim=int(self.config["projection_dim"]),
325
- hidden_dim=int(self.config["hidden_dim"]),
326
- dropout=float(self.config["dropout"]),
327
- metadata_dim=self.metadata_dim,
328
- ).to(self.device)
329
- self.head.load_state_dict(model_state_dict)
330
- self.head.eval()
331
-
332
- def encode_assay(self, assay_text: str) -> tuple[torch.Tensor, torch.Tensor | None]:
333
- embedding = self.assay_encoder.encode(
334
- [assay_text],
335
- convert_to_numpy=True,
336
- show_progress_bar=False,
337
  normalize_embeddings=True,
338
- prompt_name="query",
339
- prompt=DEFAULT_ASSAY_TASK,
340
  )[0].astype(np.float32)
341
- assay_vec = torch.from_numpy(embedding).unsqueeze(0).to(self.device)
342
- metadata_vec = _assay_metadata_vector(assay_text, dim=self.metadata_dim)
343
- metadata_tensor = None
344
- if metadata_vec.size:
345
- metadata_tensor = torch.from_numpy(metadata_vec).unsqueeze(0).to(self.device)
346
- return assay_vec, metadata_tensor
347
-
348
- def score_feature_matrix(self, assay_text: str, feature_matrix: np.ndarray) -> np.ndarray:
349
- assay_vec, metadata_tensor = self.encode_assay(assay_text)
350
- molecule_tensor = torch.from_numpy(feature_matrix).to(self.device)
351
- with torch.inference_mode():
352
- assay_repeat = assay_vec.repeat(molecule_tensor.size(0), 1)
353
- metadata_repeat = metadata_tensor.repeat(molecule_tensor.size(0), 1) if metadata_tensor is not None else None
354
- scores = self.head(assay_repeat, molecule_tensor, metadata_repeat)
355
- return scores.detach().cpu().numpy()
356
-
357
-
358
- def build_molecule_feature_vector(smiles: str, feature_spec: dict[str, Any]) -> np.ndarray | None:
359
- standardized = standardize_smiles_v2(smiles)
360
- if standardized is None:
361
- return None
362
- mol = Chem.MolFromSmiles(standardized)
363
- if mol is None:
364
- return None
365
- parts: list[np.ndarray] = []
366
- for radius in feature_spec.get("fingerprint_radii", [2, 3]):
367
- parts.append(
368
- _morgan_bits_from_mol(
369
- mol,
370
- radius=int(radius),
371
- n_bits=int(feature_spec.get("fingerprint_bits", 2048)),
372
- use_chirality=bool(feature_spec.get("use_chirality", True)),
373
- ).astype(np.float32)
374
- )
375
- if feature_spec.get("use_maccs", True):
376
- parts.append(_maccs_bits_from_mol(mol).astype(np.float32))
377
- if feature_spec.get("use_rdkit_descriptors", True):
378
- descriptor_values = _molecule_descriptor_vector(
379
- mol,
380
- names=tuple(feature_spec.get("descriptor_names", DEFAULT_DESCRIPTOR_NAMES)),
381
- )
382
- descriptor_mean = np.asarray(feature_spec["descriptor_mean"], dtype=np.float32)
383
- descriptor_std = np.asarray(feature_spec["descriptor_std"], dtype=np.float32)
384
- parts.append(((descriptor_values - descriptor_mean) / (descriptor_std + 1e-6)).astype(np.float32))
385
- if not parts:
386
- return None
387
- return np.concatenate(parts, axis=0).astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
 
390
- def load_compatibility_model(model_dir: str | Path, *, device: str | None = None) -> CompatibilityModel:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  model_path = Path(model_dir)
392
- training_metadata = json.loads((model_path / "training_metadata.json").read_text())
393
  checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
394
- assay_model_name = training_metadata["config"]["assay_model_name"]
395
- assay_encoder = SentenceTransformer(assay_model_name, device=device or ("cuda" if torch.cuda.is_available() else "cpu"))
396
- return CompatibilityModel(assay_encoder, training_metadata, checkpoint["model_state_dict"], device=device)
397
-
398
-
399
- def load_compatibility_model_from_hub(repo_id: str, *, device: str | None = None) -> CompatibilityModel:
400
- snapshot_path = snapshot_download(repo_id=repo_id, repo_type="model", allow_patterns=["best_model.pt", "training_metadata.json"])
401
- return load_compatibility_model(snapshot_path, device=device)
402
-
403
-
404
- def rank_compounds(model: CompatibilityModel, assay_text: str, smiles_list: list[str], *, top_k: int | None = None) -> list[dict[str, Any]]:
405
- valid_inputs: list[tuple[str, str, np.ndarray]] = []
406
- invalid_rows: list[dict[str, Any]] = []
407
- for item in smiles_list:
408
- feature_vec = build_molecule_feature_vector(item, model.feature_spec)
409
- standardized = standardize_smiles_v2(item)
410
- if feature_vec is None or standardized is None:
411
- invalid_rows.append({"input_smiles": item, "valid": False, "error": "invalid_smiles"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  continue
413
- valid_inputs.append((item, standardized, feature_vec))
414
- valid_rows: list[dict[str, Any]] = []
415
- if valid_inputs:
416
- feature_matrix = np.stack([entry[2] for entry in valid_inputs], axis=0).astype(np.float32)
417
- scores = model.score_feature_matrix(assay_text, feature_matrix)
418
- for (input_smiles, standardized, _), score in zip(valid_inputs, scores):
419
- valid_rows.append(
 
 
 
 
 
 
 
420
  {
421
- "input_smiles": input_smiles,
422
- "canonical_smiles": standardized,
423
- "smiles_hash": smiles_sha256(standardized),
424
  "score": float(score),
425
  "valid": True,
426
  }
427
  )
428
- valid_rows.sort(key=lambda item: item["score"], reverse=True)
429
- if top_k:
430
- valid_rows = valid_rows[:top_k]
431
- return valid_rows + invalid_rows
 
432
 
433
 
434
  def list_softmax_scores(scores: list[float], temperature: float = 1.0) -> list[float]:
 
39
  "ASSAY_TYPE",
40
  "TARGET_UNIPROT",
41
  ]
42
+ ASSAY_SECTION_RE = re.compile(
43
+ r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n"
44
+ )
45
  ORGANISM_ALIASES = {
46
  "9606": "homo_sapiens",
47
  "10090": "mus_musculus",
 
108
  return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)
109
 
110
 
111
+ def _format_assay_query(assay_text: str, task_description: str) -> str:
112
+ return f"Instruct: {task_description.strip()}\nQuery: {assay_text.strip()}"
113
+
114
+
115
  def _parse_assay_sections(assay_text: str) -> dict[str, str]:
116
  sections = {key: "" for key in SECTION_ORDER}
117
  parts = ASSAY_SECTION_RE.split(assay_text)
 
225
  fragments = Chem.GetMolFrags(mol)
226
  formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
227
  max_atomic_num = max(counts) if counts else 0
228
+ metal_atom_count = sum(
229
+ count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS
230
+ )
231
  halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
232
  aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
233
  values = {
 
266
 
267
 
268
class CompatibilityHead(nn.Module):
    """Two-tower scoring head for assay/molecule compatibility.

    Each tower is LayerNorm followed by a linear projection into a shared
    space; projected vectors are L2-normalized. The final score is a learnable
    blend of the dot product of the two embeddings and a small MLP over their
    interaction features ``[a, m, a * m, |a - m|]``.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        # Assay tower: LayerNorm -> Linear (with bias).
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        # Molecule tower: LayerNorm -> bias-free Linear.
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        # MLP over the 4-way interaction features, hence projection_dim * 4 inputs.
        self.score_mlp = nn.Sequential(
            nn.Linear(projection_dim * 4, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        # Learnable scale blending the dot-product score into the MLP score.
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features into the shared space and L2-normalize."""
        vec = self.assay_proj(self.assay_norm(assay_features))
        return F.normalize(vec, p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features into the shared space and L2-normalize."""
        vec = self.mol_proj(self.mol_norm(molecule_features))
        return F.normalize(vec, p=2, dim=-1)

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every candidate molecule against each assay.

        The unsqueeze/expand below implies ``assay_features`` is
        ``(batch, assay_dim)`` and ``candidate_features`` is
        ``(batch, num_candidates, molecule_dim)`` — confirm against callers.
        Returns ``(logits, assay_vec, mol_vec)``.
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        # Broadcast each assay embedding across its candidate list.
        assay_expand = assay_vec.unsqueeze(1).expand(-1, mol_vec.shape[1], -1)
        # Dot product of unit-norm vectors, i.e. cosine similarity.
        dot_scores = (assay_expand * mol_vec).sum(dim=-1)
        mlp_input = torch.cat(
            [
                assay_expand,
                mol_vec,
                assay_expand * mol_vec,
                torch.abs(assay_expand - mol_vec),
            ],
            dim=-1,
        )
        mlp_scores = self.score_mlp(mlp_input).squeeze(-1)
        logits = dot_scores * self.dot_scale + mlp_scores
        return logits, assay_vec, mol_vec

    def score_pairs(self, assay_features: torch.Tensor, molecule_features: torch.Tensor) -> torch.Tensor:
        """Score aligned (assay, molecule) pairs sharing a leading batch dim."""
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(molecule_features)
        dot_scores = (assay_vec * mol_vec).sum(dim=-1)
        mlp_input = torch.cat(
            [assay_vec, mol_vec, assay_vec * mol_vec, torch.abs(assay_vec - mol_vec)],
            dim=-1,
        )
        mlp_scores = self.score_mlp(mlp_input).squeeze(-1)
        return dot_scores * self.dot_scale + mlp_scores
319
 
320
 
321
class BioAssayAlignCompatibilityModel:
    """Inference wrapper pairing a sentence-transformer assay encoder with a
    trained :class:`CompatibilityHead` and the molecule featurization settings
    used at training time."""

    def __init__(
        self,
        assay_encoder: SentenceTransformer,
        compatibility_head: CompatibilityHead,
        *,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: np.ndarray | None,
        descriptor_std: np.ndarray | None,
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        self.assay_encoder = assay_encoder
        # Inference-only: put the head in eval mode so Dropout is disabled.
        self.compatibility_head = compatibility_head.eval()
        # Instruction text prepended to every assay query (see _format_assay_query).
        self.assay_task_description = assay_task_description
        # Morgan fingerprint settings used to featurize molecules.
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        # RDKit descriptor block plus its training-set normalization stats.
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        # Optional hand-crafted assay metadata features appended to the embedding.
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim

    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
        """Encode *assay_text* into the head's assay input vector.

        Returns the normalized sentence embedding, with the metadata feature
        vector concatenated when metadata features are enabled.
        """
        query = _format_assay_query(assay_text, self.assay_task_description)
        assay_features = self.assay_encoder.encode(
            [query],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
        return assay_features

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
        """Featurize each SMILES into a row of a float32 matrix.

        Propagates ValueError from _smiles_to_molecule_features for unparsable
        SMILES; callers should pre-filter (see rank_compounds).
        """
        rows: list[np.ndarray] = []
        for smiles in smiles_values:
            rows.append(
                _smiles_to_molecule_features(
                    smiles,
                    radii=self.fingerprint_radii,
                    n_bits=self.fingerprint_bits,
                    use_chirality=self.use_chirality,
                    use_maccs=self.use_maccs,
                    use_rdkit_descriptors=self.use_rdkit_descriptors,
                    descriptor_names=self.descriptor_names,
                    descriptor_mean=self.descriptor_mean,
                    descriptor_std=self.descriptor_std,
                )
            )
        return np.stack(rows, axis=0).astype(np.float32)

    def score(self, assay_text: str, smiles: str) -> float:
        """Return the compatibility logit for a single (assay, molecule) pair."""
        assay_features = self._build_assay_feature_array(assay_text)
        molecule_features = self.build_molecule_feature_matrix([smiles])[0]
        assay_tensor = torch.from_numpy(assay_features).unsqueeze(0)
        molecule_tensor = torch.from_numpy(molecule_features).unsqueeze(0)
        with torch.no_grad():
            score = self.compatibility_head.score_pairs(assay_tensor, molecule_tensor)
        return float(score.item())
393
+
394
+
395
def _load_sentence_transformer(model_name: str):
    """Construct the assay text encoder, preferring bfloat16 when CUDA is present."""
    if torch.cuda.is_available():
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float32
    encoder = SentenceTransformer(
        model_name,
        trust_remote_code=True,
        model_kwargs={"torch_dtype": torch_dtype},
    )
    # Left padding keeps the informative tokens at the end of the sequence.
    tokenizer = getattr(encoder, "tokenizer", None)
    if tokenizer is not None:
        tokenizer.padding_side = "left"
    return encoder
405
+
406
+
407
+ def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
408
+ spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
409
+ if spec:
410
+ return spec
411
+ radii = tuple(int(item) for item in (cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)]))
412
+ return {
413
+ "fingerprint_radii": list(radii),
414
+ "fingerprint_bits": int(cfg["fingerprint_bits"]),
415
+ "use_chirality": bool(cfg.get("use_chirality", False)),
416
+ "use_maccs": bool(cfg.get("use_maccs", False)),
417
+ "use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
418
+ "descriptor_names": [],
419
+ "descriptor_mean": None,
420
+ "descriptor_std": None,
421
+ }
422
 
423
 
424
def _smiles_to_molecule_features(
    smiles: str,
    *,
    radii: tuple[int, ...],
    n_bits: int,
    use_chirality: bool,
    use_maccs: bool,
    use_rdkit_descriptors: bool,
    descriptor_names: tuple[str, ...],
    descriptor_mean: np.ndarray | None,
    descriptor_std: np.ndarray | None,
) -> np.ndarray:
    """Build the flat float32 feature vector for one SMILES string.

    Layout: Morgan bits for each radius (in order), then MACCS keys when
    enabled, then (optionally normalized) RDKit descriptors.

    Raises ValueError when the SMILES cannot be parsed by RDKit.
    """
    # Fall back to the raw string when standardization rejects it; RDKit
    # parsing below is the final arbiter of validity.
    normalized = standardize_smiles_v2(smiles) or smiles
    mol = Chem.MolFromSmiles(normalized)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {normalized}")
    fingerprint_parts: list[np.ndarray] = []
    for fp_radius in radii:
        fingerprint_parts.append(
            _morgan_bits_from_mol(mol, radius=int(fp_radius), n_bits=n_bits, use_chirality=use_chirality)
        )
    if use_maccs:
        fingerprint_parts.append(_maccs_bits_from_mol(mol))
    feature_parts: list[np.ndarray] = [np.concatenate(fingerprint_parts, axis=0).astype(np.float32)]
    if use_rdkit_descriptors and descriptor_names:
        descriptor_vec = _molecule_descriptor_vector(mol, names=descriptor_names)
        # Apply training-set standardization only when both stats are present.
        if descriptor_mean is not None and descriptor_std is not None:
            descriptor_vec = (descriptor_vec - descriptor_mean) / descriptor_std
        feature_parts.append(descriptor_vec.astype(np.float32))
    return np.concatenate(feature_parts, axis=0).astype(np.float32)
453
+
454
+
455
def load_compatibility_model(model_dir: str | Path) -> BioAssayAlignCompatibilityModel:
    """Load a trained compatibility model (head + assay encoder) from *model_dir*.

    Expects ``best_model.pt`` (torch checkpoint containing ``model_state_dict``)
    and ``training_metadata.json`` (with a ``config`` section) in the directory.
    """
    model_path = Path(model_dir)
    # weights_only=False because the checkpoint carries Python metadata alongside
    # tensors. NOTE(review): this unpickles arbitrary objects — load trusted files only.
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    metadata = json.loads((model_path / "training_metadata.json").read_text())
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)

    # Checkpoint value wins over config so the encoder matches the trained weights.
    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])
    # Infer tower input widths from the saved weight shapes rather than the config.
    assay_dim = int(checkpoint["model_state_dict"]["assay_proj.weight"].shape[1])
    molecule_dim = int(checkpoint["model_state_dict"]["mol_proj.weight"].shape[1])
    head = CompatibilityHead(
        assay_dim=assay_dim,
        molecule_dim=molecule_dim,
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    # strict=False so checkpoints saved without the molecule LayerNorm params
    # still load; any other mismatch is treated as a hard error below.
    load_result = head.load_state_dict(checkpoint["model_state_dict"], strict=False)
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            "Compatibility checkpoint load mismatch: "
            f"unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )
    return BioAssayAlignCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg["assay_task_description"],
        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", False)),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", False)),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=np.array(feature_spec["descriptor_mean"], dtype=np.float32)
        if feature_spec.get("descriptor_mean") is not None
        else None,
        descriptor_std=np.array(feature_spec["descriptor_std"], dtype=np.float32)
        if feature_spec.get("descriptor_std") is not None
        else None,
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )
+
501
+
502
def load_compatibility_model_from_hub(repo_id: str) -> BioAssayAlignCompatibilityModel:
    """Download the checkpoint and metadata from the Hugging Face Hub, then load."""
    wanted_files = ["best_model.pt", "training_metadata.json"]
    local_dir = snapshot_download(repo_id=repo_id, repo_type="model", allow_patterns=wanted_files)
    return load_compatibility_model(local_dir)
509
+
510
+
511
def rank_compounds(
    model: BioAssayAlignCompatibilityModel,
    *,
    assay_text: str,
    smiles_list: list[str],
    top_k: int | None = None,
) -> list[dict[str, Any]]:
    """Score *smiles_list* against *assay_text* and return ranked result rows.

    Returns score-descending rows for valid SMILES (optionally truncated to
    *top_k*), followed by rows for SMILES that failed standardization
    (``valid: False``, ``error: "invalid_smiles"``).
    """
    if not smiles_list:
        return []

    assay_features = model._build_assay_feature_array(assay_text)
    assay_tensor = torch.from_numpy(assay_features.astype(np.float32)).unsqueeze(0)

    # Partition inputs up front so a single bad SMILES cannot abort scoring.
    valid_items: list[tuple[str, str]] = []
    invalid_items: list[dict[str, Any]] = []
    for raw_smiles in smiles_list:
        standardized = standardize_smiles_v2(raw_smiles)
        if standardized is None:
            invalid_items.append(
                {
                    "input_smiles": raw_smiles,
                    "canonical_smiles": None,
                    "smiles_hash": None,
                    "score": None,
                    "valid": False,
                    "error": "invalid_smiles",
                }
            )
            continue
        valid_items.append((raw_smiles, standardized))

    ranked_items: list[dict[str, Any]] = []
    if valid_items:
        # Featurize the standardized SMILES and score all candidates in one
        # batch (batch size 1, num_candidates = len(valid_items)).
        feature_matrix = model.build_molecule_feature_matrix([item[1] for item in valid_items])
        candidate_tensor = torch.from_numpy(feature_matrix).unsqueeze(0)
        with torch.no_grad():
            logits, _, _ = model.compatibility_head.score_candidates(
                assay_tensor.to(dtype=torch.float32),
                candidate_tensor.to(dtype=torch.float32),
            )
        scores = logits.squeeze(0).cpu().numpy().tolist()
        # strict=True guards against a featurization/score length mismatch.
        for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True):
            ranked_items.append(
                {
                    "input_smiles": raw_smiles,
                    "canonical_smiles": canonical,
                    "smiles_hash": smiles_sha256(canonical),
                    "score": float(score),
                    "valid": True,
                }
            )
    ranked_items.sort(key=lambda item: item["score"], reverse=True)
    if top_k is not None and top_k > 0:
        ranked_items = ranked_items[:top_k]

    # Invalid rows are always appended, even when top_k truncates the ranking.
    return ranked_items + invalid_items
567
 
568
 
569
  def list_softmax_scores(scores: list[float], temperature: float = 1.0) -> list[float]: