"""Inference helpers for ranking small molecules against a bioassay description.

The module standardizes SMILES with RDKit, builds molecule feature vectors
(Morgan/MACCS bits, optional RDKit descriptors, optional transformer
embeddings), embeds the assay text with a SentenceTransformer, and scores
assay/molecule pairs with a small PyTorch head loaded from a checkpoint.
"""

from __future__ import annotations

import contextlib
import hashlib
import io
import json
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors
from rdkit.Chem.MolStandardize import rdMolStandardize
from sentence_transformers import SentenceTransformer
from torch import nn
from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging as transformers_logging

# Import-time side effects: silence progress bars and library logging.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
disable_progress_bars()
transformers_logging.set_verbosity_error()
RDLogger.DisableLog("rdApp.*")

# Instruction text prepended to every assay query (see
# SpaceCompatibilityModel._format_assay_query) when the checkpoint/config does
# not provide its own.
DEFAULT_ASSAY_TASK = (
    "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
)

# Names (and default order) of the values produced by _molecule_descriptor_vector.
DEFAULT_DESCRIPTOR_NAMES = (
    "mol_wt",
    "logp",
    "tpsa",
    "heavy_atoms",
    "hbd",
    "hba",
    "rot_bonds",
    "ring_count",
    "aromatic_rings",
    "aliphatic_rings",
    "saturated_rings",
    "fraction_csp3",
    "heteroatoms",
    "amide_bonds",
    "fragments",
    "formal_charge",
    "max_atomic_num",
    "metal_atom_count",
    "halogen_count",
    "nitrogen_count",
    "oxygen_count",
    "sulfur_count",
    "phosphorus_count",
    "fluorine_count",
    "chlorine_count",
    "bromine_count",
    "iodine_count",
    "aromatic_atom_count",
    "spiro_atoms",
    "bridgehead_atoms",
)

# Atomic numbers treated as "organic-like"; every other element is counted in
# the "metal_atom_count" descriptor.
ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}

# Section headers, in the order serialize_assay_query emits them.
SECTION_ORDER = [
    "ASSAY_TITLE",
    "DESCRIPTION",
    "ORGANISM",
    "READOUT",
    "ASSAY_FORMAT",
    "ASSAY_TYPE",
    "TARGET_UNIPROT",
]

# Matches a "[SECTION]" header followed by a newline; the capture group makes
# re.split return the header names interleaved with the section bodies.
ASSAY_SECTION_RE = re.compile(r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n")

# Organism values that are replaced by a name before token normalization.
ORGANISM_ALIASES = {
    "9606": "homo_sapiens",
    "10090": "mus_musculus",
    "10116": "rattus_norvegicus",
    "4932": "saccharomyces_cerevisiae",
}


@dataclass
class AssayQuery:
    """Structured assay fields rendered into text by serialize_assay_query."""

    title: str = ""
    description: str = ""
    organism: str = ""
    readout: str = ""
    assay_format: str = ""
    assay_type: str = ""
    # None is treated like an empty list when serializing.
    target_uniprot: list[str] | None = None


def smiles_sha256(smiles: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoded SMILES string."""
    return hashlib.sha256(smiles.encode("utf-8")).hexdigest()


@contextlib.contextmanager
def _silent_imports():
    """Redirect stdout and stderr into a throwaway buffer for the ``with`` body."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
        yield


@lru_cache(maxsize=1_000_000)
def _standardize_smiles_v2_cached(smiles: str) -> str | None:
    """Standardize one SMILES string; results are memoized per input string.

    Pipeline: parse, Cleanup, FragmentParent, uncharge, canonical tautomer,
    sanitize, then write canonical isomeric SMILES.

    Returns:
        The standardized SMILES, or None when the input does not parse, any
        standardization step raises, the result has fewer than two heavy
        atoms, or the output is empty or still contains a "." separator.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        mol = rdMolStandardize.Cleanup(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        mol = rdMolStandardize.Uncharger().uncharge(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
        Chem.SanitizeMol(mol)
    except Exception:
        # Deliberate best effort: any RDKit failure marks the input as invalid.
        return None
    if mol.GetNumHeavyAtoms() < 2:
        return None
    standardized = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
    if not standardized or "." in standardized:
        return None
    return standardized


def standardize_smiles_v2(smiles: str | None) -> str | None:
    """Strip and standardize a SMILES string.

    Returns None for None, empty or whitespace-only input, and for anything
    _standardize_smiles_v2_cached rejects.
    """
    if not smiles:
        return None
    token = smiles.strip()
    if not token:
        return None
    return _standardize_smiles_v2_cached(token)


def serialize_assay_query(query: AssayQuery) -> str:
    """Render an AssayQuery as "[SECTION]\\nvalue" blocks joined by blank lines.

    Every section in SECTION_ORDER is emitted, even when its value is empty.
    Target accessions are joined with ", ".
    """
    targets = ", ".join(query.target_uniprot or [])
    values = {
        "ASSAY_TITLE": query.title.strip(),
        "DESCRIPTION": query.description.strip(),
        "ORGANISM": query.organism.strip(),
        "READOUT": query.readout.strip(),
        "ASSAY_FORMAT": query.assay_format.strip(),
        "ASSAY_TYPE": query.assay_type.strip(),
        "TARGET_UNIPROT": targets.strip(),
    }
    return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)


def _parse_assay_sections(assay_text: str) -> dict[str, str]:
    """Split serialized assay text back into a section-name -> value mapping.

    Sections that do not appear in the text map to "". If a header occurs
    more than once, the last occurrence wins.
    """
    sections = {key: "" for key in SECTION_ORDER}
    # parts = [preamble, header, body, header, body, ...]
    parts = ASSAY_SECTION_RE.split(assay_text)
    for idx in range(1, len(parts), 2):
        key = parts[idx]
        value = parts[idx + 1] if idx + 1 < len(parts) else ""
        if key in sections:
            sections[key] = value.strip()
    return sections


def _hash_bucket(value: str, dim: int) -> int:
    """Map a token to a bucket index in ``[0, max(dim, 1))``.

    NOTE(review): the built-in ``hash()`` of a ``str`` is salted per
    interpreter process unless PYTHONHASHSEED is fixed, so bucket assignments
    are only reproducible across runs in that case. Switching to a stable
    digest (e.g. hashlib) would have to be done together with whatever code
    produced the checkpoint's metadata features, so it is flagged rather than
    changed here.
    """
    return abs(hash(value)) % max(dim, 1)


def _normalize_metadata_token(value: str) -> str:
    """Lower-case ``value`` and collapse non-alphanumeric runs to single "_"."""
    return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_")


def _normalize_organism_token(value: str) -> str:
    """Normalize an organism field, applying ORGANISM_ALIASES first."""
    raw = value.strip()
    if not raw:
        return ""
    aliased = ORGANISM_ALIASES.get(raw, raw)
    return _normalize_metadata_token(aliased)


def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
    """Build an L2-normalized hashed bag-of-tokens vector from assay metadata.

    Tokens are derived from the ORGANISM, READOUT, ASSAY_FORMAT, ASSAY_TYPE
    and TARGET_UNIPROT sections. Each token increments one bucket chosen by
    _hash_bucket.

    Returns:
        A float32 vector of length ``dim``; an empty vector when ``dim <= 0``
        and an all-zero vector when no tokens are present.
    """
    if dim <= 0:
        return np.zeros((0,), dtype=np.float32)
    sections = _parse_assay_sections(assay_text)
    tokens: list[str] = []
    organism = _normalize_organism_token(sections.get("ORGANISM", ""))
    if organism:
        tokens.append(f"organism:{organism}")
    for key in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE"):
        value = _normalize_metadata_token(sections.get(key, ""))
        if value:
            tokens.append(f"{key.lower()}:{value}")
    for target in sections.get("TARGET_UNIPROT", "").split(","):
        token = target.strip().upper()
        if token:
            tokens.append(f"target:{token}")
    vec = np.zeros((dim,), dtype=np.float32)
    for token in tokens:
        vec[_hash_bucket(token, dim)] += 1.0
    norm = float(np.linalg.norm(vec))
    if norm > 0:
        vec /= norm
    return vec


def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
    """Return the Morgan bit-vector fingerprint of ``mol`` as a uint8 array."""
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits, useChirality=use_chirality)
    arr = np.zeros((n_bits,), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr


def _maccs_bits_from_mol(mol) -> np.ndarray:
    """Return the MACCS keys fingerprint of ``mol`` as a uint8 array."""
    fp = MACCSkeys.GenMACCSKeys(mol)
    arr = np.zeros((fp.GetNumBits(),), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr


def _count_atomic_nums(mol) -> dict[int, int]:
    """Count the atoms of ``mol`` by atomic number."""
    counts: dict[int, int] = {}
    for atom in mol.GetAtoms():
        atomic_num = int(atom.GetAtomicNum())
        counts[atomic_num] = counts.get(atomic_num, 0) + 1
    return counts


def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
    """Compute the requested descriptors for ``mol`` as a float32 vector.

    Args:
        mol: RDKit molecule.
        names: Descriptor names to emit, in output order. Every name must be
            one of DEFAULT_DESCRIPTOR_NAMES.

    Raises:
        KeyError: If ``names`` contains an unknown descriptor name.
    """
    counts = _count_atomic_nums(mol)
    fragments = Chem.GetMolFrags(mol)
    formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
    max_atomic_num = max(counts) if counts else 0
    # "Metal" here means any element outside ORGANIC_LIKE_ATOMIC_NUMBERS.
    metal_atom_count = sum(
        count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS
    )
    halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
    aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
    values = {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": float(mol.GetNumHeavyAtoms()),
        "hbd": float(Lipinski.NumHDonors(mol)),
        "hba": float(Lipinski.NumHAcceptors(mol)),
        "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
        "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
        "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
        "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
        "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
        "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
        "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
        "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
        "fragments": float(len(fragments)),
        "formal_charge": float(formal_charge),
        "max_atomic_num": float(max_atomic_num),
        "metal_atom_count": float(metal_atom_count),
        "halogen_count": float(halogen_count),
        "nitrogen_count": float(counts.get(7, 0)),
        "oxygen_count": float(counts.get(8, 0)),
        "sulfur_count": float(counts.get(16, 0)),
        "phosphorus_count": float(counts.get(15, 0)),
        "fluorine_count": float(counts.get(9, 0)),
        "chlorine_count": float(counts.get(17, 0)),
        "bromine_count": float(counts.get(35, 0)),
        "iodine_count": float(counts.get(53, 0)),
        "aromatic_atom_count": float(aromatic_atom_count),
        "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
        "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
    }
    return np.array([values[name] for name in names], dtype=np.float32)


def molecule_ui_metrics(smiles: str) -> dict[str, float | int]:
    """Return mol_wt, logp, tpsa and heavy_atoms for a SMILES string.

    The standardized SMILES is used when standardization succeeds, otherwise
    the raw input. All four metrics are zero when the molecule cannot be
    parsed.
    """
    canonical = standardize_smiles_v2(smiles) or smiles
    mol = Chem.MolFromSmiles(canonical)
    if mol is None:
        return {
            "mol_wt": 0.0,
            "logp": 0.0,
            "tpsa": 0.0,
            "heavy_atoms": 0,
        }
    return {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": int(mol.GetNumHeavyAtoms()),
    }


class CompatibilityHead(nn.Module):
    """Scores assay/molecule pairs from their feature vectors.

    Both inputs are layer-normalized, linearly projected to ``projection_dim``
    and L2-normalized. The score is a scaled dot product of the two
    projections plus the output of an MLP over their concatenated
    interaction features.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        # Input is [assay, mol, assay * mol, |assay - mol|], hence 4x projection_dim.
        self.score_mlp = nn.Sequential(
            nn.Linear(projection_dim * 4, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features and L2-normalize along the last dimension."""
        vec = self.assay_proj(self.assay_norm(assay_features))
        return F.normalize(vec, p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features and L2-normalize along the last dimension."""
        vec = self.mol_proj(self.mol_norm(molecule_features))
        return F.normalize(vec, p=2, dim=-1)

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every candidate molecule against its assay.

        Args:
            assay_features: 2-D tensor ``(batch, assay_dim)``.
            candidate_features: 3-D tensor ``(batch, n_candidates, molecule_dim)``.

        Returns:
            ``(logits, assay_vec, mol_vec)`` where ``logits`` has shape
            ``(batch, n_candidates)`` and the other two are the normalized
            projections.
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        # Repeat each assay vector across its candidates without copying.
        assay_expand = assay_vec.unsqueeze(1).expand(-1, mol_vec.shape[1], -1)
        dot_scores = (assay_expand * mol_vec).sum(dim=-1)
        mlp_input = torch.cat(
            [assay_expand, mol_vec, assay_expand * mol_vec, torch.abs(assay_expand - mol_vec)],
            dim=-1,
        )
        mlp_scores = self.score_mlp(mlp_input).squeeze(-1)
        logits = dot_scores * self.dot_scale + mlp_scores
        return logits, assay_vec, mol_vec


class SpaceCompatibilityModel:
    """Bundles the assay encoder, the scoring head and the feature settings.

    The molecule transformer (if a model name is configured) is loaded lazily
    on first use.
    """

    def __init__(
        self,
        *,
        assay_encoder: SentenceTransformer,
        compatibility_head: CompatibilityHead,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: np.ndarray | None,
        descriptor_std: np.ndarray | None,
        molecule_transformer_model_name: str,
        molecule_transformer_batch_size: int,
        molecule_transformer_max_length: int,
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        self.assay_encoder = assay_encoder
        # The head is only used for inference, so dropout is disabled up front.
        self.compatibility_head = compatibility_head.eval()
        self.assay_task_description = assay_task_description
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        self.molecule_transformer_model_name = molecule_transformer_model_name
        self.molecule_transformer_batch_size = molecule_transformer_batch_size
        self.molecule_transformer_max_length = molecule_transformer_max_length
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim
        self._molecule_transformer_tokenizer = None
        self._molecule_transformer_model = None
        self._molecule_transformer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _format_assay_query(self, assay_text: str) -> str:
        """Wrap the assay text in the "Instruct: ... / Query: ..." prompt."""
        return f"Instruct: {self.assay_task_description.strip()}\nQuery: {assay_text.strip()}"

    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
        """Return the float32 assay feature vector for ``assay_text``.

        The normalized sentence embedding comes first; when metadata features
        are enabled, the hashed metadata vector is appended to it.
        """
        assay_features = self.assay_encoder.encode(
            [self._format_assay_query(assay_text)],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
        return assay_features

    def _ensure_molecule_transformer_loaded(self) -> None:
        """Load the molecule tokenizer and model once; no-op if unset or loaded."""
        if not self.molecule_transformer_model_name or self._molecule_transformer_model is not None:
            return
        dtype = torch.float16 if self._molecule_transformer_device.type == "cuda" else torch.float32
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; only point this at repositories you trust.
        with _silent_imports():
            self._molecule_transformer_tokenizer = AutoTokenizer.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
            )
            self._molecule_transformer_model = AutoModel.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
                torch_dtype=dtype,
            ).to(self._molecule_transformer_device)
        self._molecule_transformer_model.eval()

    def _encode_molecule_transformer_batch(self, smiles_values: list[str]) -> np.ndarray | None:
        """Embed SMILES strings with the molecule transformer.

        Embeddings are the attention-mask-weighted mean of the last hidden
        state, computed in batches of ``molecule_transformer_batch_size``.

        Args:
            smiles_values: Non-empty list of SMILES strings (an empty list
                leaves nothing to concatenate and raises ValueError).

        Returns:
            A float32 matrix with one row per input, or None when no molecule
            transformer is configured.
        """
        if not self.molecule_transformer_model_name:
            return None
        self._ensure_molecule_transformer_loaded()
        assert self._molecule_transformer_model is not None
        assert self._molecule_transformer_tokenizer is not None
        outputs: list[np.ndarray] = []
        batch_size = max(self.molecule_transformer_batch_size, 1)
        with torch.no_grad():
            for start in range(0, len(smiles_values), batch_size):
                batch = smiles_values[start : start + batch_size]
                encoded = self._molecule_transformer_tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.molecule_transformer_max_length,
                    return_tensors="pt",
                )
                encoded = {key: value.to(self._molecule_transformer_device) for key, value in encoded.items()}
                hidden = self._molecule_transformer_model(**encoded).last_hidden_state
                mask = encoded["attention_mask"].unsqueeze(-1)
                # clamp avoids dividing by zero for a fully masked row.
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                outputs.append(pooled.detach().cpu().to(torch.float32).numpy())
        return np.concatenate(outputs, axis=0).astype(np.float32)

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
        """Build the float32 feature matrix, one row per SMILES.

        Each row concatenates, in order: Morgan bits for every configured
        radius, MACCS bits (optional), standardized RDKit descriptors
        (optional) and the transformer embedding (optional).

        Args:
            smiles_values: SMILES strings. Each is standardized when possible,
                otherwise parsed as given.

        Returns:
            Array of shape ``(len(smiles_values), n_features)``. For empty
            input an empty ``(0, molecule_dim)`` matrix is returned, where
            ``molecule_dim`` is the input width of the scoring head.

        Raises:
            ValueError: If a SMILES string cannot be parsed.
        """
        if len(smiles_values) == 0:
            # Nothing to stack; keep the column count the head expects.
            feature_dim = int(self.compatibility_head.mol_proj.in_features)
            return np.zeros((0, feature_dim), dtype=np.float32)

        transformer_matrix = self._encode_molecule_transformer_batch(smiles_values)

        # Guard against zero-variance descriptors: dividing by a zero std would
        # put inf/NaN into the features. Non-zero entries are left untouched.
        safe_std: np.ndarray | None = None
        if self.descriptor_mean is not None and self.descriptor_std is not None:
            safe_std = np.array(self.descriptor_std)
            safe_std[safe_std == 0] = 1.0

        rows: list[np.ndarray] = []
        for idx, smiles in enumerate(smiles_values):
            normalized = standardize_smiles_v2(smiles) or smiles
            mol = Chem.MolFromSmiles(normalized)
            if mol is None:
                raise ValueError(f"Could not parse SMILES: {normalized}")
            bit_blocks: list[np.ndarray] = [
                _morgan_bits_from_mol(
                    mol,
                    radius=int(radius),
                    n_bits=self.fingerprint_bits,
                    use_chirality=self.use_chirality,
                )
                for radius in self.fingerprint_radii
            ]
            if self.use_maccs:
                bit_blocks.append(_maccs_bits_from_mol(mol))
            output_blocks: list[np.ndarray] = [np.concatenate(bit_blocks, axis=0).astype(np.float32)]
            if self.use_rdkit_descriptors and self.descriptor_names:
                dense = _molecule_descriptor_vector(mol, names=self.descriptor_names)
                if safe_std is not None:
                    dense = (dense - self.descriptor_mean) / safe_std
                output_blocks.append(dense.astype(np.float32))
            if transformer_matrix is not None:
                output_blocks.append(np.asarray(transformer_matrix[idx], dtype=np.float32))
            rows.append(np.concatenate(output_blocks, axis=0).astype(np.float32))
        return np.stack(rows, axis=0)


def _load_sentence_transformer(model_name: str) -> SentenceTransformer:
    """Load the assay encoder (bfloat16 when CUDA is available, else float32).

    The tokenizer, when present, is switched to left padding.
    """
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    with _silent_imports():
        encoder = SentenceTransformer(
            model_name,
            trust_remote_code=True,
            model_kwargs={"torch_dtype": dtype},
        )
    if getattr(encoder, "tokenizer", None) is not None:
        encoder.tokenizer.padding_side = "left"
    return encoder


def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Return the molecule feature spec for a checkpoint.

    Preference order: the checkpoint's "molecule_feature_spec", then the
    metadata's, then a spec derived from the training config. The derived
    spec carries no descriptor names or statistics.
    """
    spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
    if spec:
        return spec
    radii = tuple(int(item) for item in (cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)]))
    return {
        "fingerprint_radii": list(radii),
        "fingerprint_bits": int(cfg["fingerprint_bits"]),
        "use_chirality": bool(cfg.get("use_chirality", False)),
        "use_maccs": bool(cfg.get("use_maccs", False)),
        "use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
        "descriptor_names": [],
        "descriptor_mean": None,
        "descriptor_std": None,
        "molecule_transformer_model_name": str(cfg.get("molecule_transformer_model_name") or ""),
        "molecule_transformer_max_length": int(cfg.get("molecule_transformer_max_length", 128) or 128),
    }


def load_compatibility_model(model_dir: str | Path) -> SpaceCompatibilityModel:
    """Load a SpaceCompatibilityModel from a local directory.

    The directory must contain "best_model.pt" and "training_metadata.json".
    Input dimensions of the head are read from the checkpoint weights.

    Raises:
        RuntimeError: If the checkpoint has unexpected keys, or is missing
            keys other than the ``mol_norm`` parameters.
    """
    model_path = Path(model_dir)
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    metadata = json.loads((model_path / "training_metadata.json").read_text())
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)
    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])

    state_dict = checkpoint["model_state_dict"]
    assay_dim = int(state_dict["assay_proj.weight"].shape[1])
    molecule_dim = int(state_dict["mol_proj.weight"].shape[1])
    head = CompatibilityHead(
        assay_dim=assay_dim,
        molecule_dim=molecule_dim,
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    # strict=False so checkpoints without mol_norm parameters still load (the
    # LayerNorm then keeps its default initialization); anything else is an error.
    load_result = head.load_state_dict(state_dict, strict=False)
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            f"Checkpoint mismatch: unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )

    descriptor_mean = (
        np.array(feature_spec["descriptor_mean"], dtype=np.float32)
        if feature_spec.get("descriptor_mean") is not None
        else None
    )
    descriptor_std = (
        np.array(feature_spec["descriptor_std"], dtype=np.float32)
        if feature_spec.get("descriptor_std") is not None
        else None
    )
    return SpaceCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg.get("assay_task_description", DEFAULT_ASSAY_TASK),
        fingerprint_radii=tuple(int(item) for item in (feature_spec.get("fingerprint_radii") or [2])),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", cfg.get("use_maccs", False))),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", cfg.get("use_rdkit_descriptors", False))),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=descriptor_mean,
        descriptor_std=descriptor_std,
        molecule_transformer_model_name=str(feature_spec.get("molecule_transformer_model_name") or cfg.get("molecule_transformer_model_name") or ""),
        molecule_transformer_batch_size=int(cfg.get("molecule_transformer_batch_size", 128) or 128),
        molecule_transformer_max_length=int(feature_spec.get("molecule_transformer_max_length") or cfg.get("molecule_transformer_max_length", 128) or 128),
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )


@lru_cache(maxsize=1)
def load_compatibility_model_from_hub(model_repo_id: str) -> SpaceCompatibilityModel:
    """Download a model snapshot from the Hugging Face Hub and load it.

    Only the most recently requested repository's model is kept cached.
    """
    with _silent_imports():
        model_dir = snapshot_download(
            repo_id=model_repo_id,
            repo_type="model",
            allow_patterns=["best_model.pt", "training_metadata.json", "README.md"],
        )
    return load_compatibility_model(model_dir)


def rank_compounds(
    model: SpaceCompatibilityModel,
    *,
    assay_text: str,
    smiles_list: list[str],
    top_k: int | None = None,
) -> list[dict[str, Any]]:
    """Score and rank SMILES strings against an assay description.

    Args:
        model: Loaded compatibility model.
        assay_text: Assay text, e.g. the output of serialize_assay_query.
        smiles_list: Candidate SMILES strings.
        top_k: When a positive integer, keep only the ``top_k`` best valid
            compounds. None or a non-positive value keeps all of them.

    Returns:
        Valid compounds sorted by descending score, followed by one entry per
        input that failed standardization (``valid`` False, ``score`` None,
        ``error`` "invalid_smiles"). Invalid entries are never truncated by
        ``top_k``. An empty ``smiles_list`` yields an empty list.
    """
    if not smiles_list:
        return []
    assay_features = model._build_assay_feature_array(assay_text)
    assay_tensor = torch.from_numpy(assay_features.astype(np.float32)).unsqueeze(0)

    valid_items: list[tuple[str, str]] = []
    invalid_items: list[dict[str, Any]] = []
    for raw_smiles in smiles_list:
        standardized = standardize_smiles_v2(raw_smiles)
        if standardized is None:
            invalid_items.append(
                {
                    "input_smiles": raw_smiles,
                    "canonical_smiles": None,
                    "smiles_hash": None,
                    "score": None,
                    "valid": False,
                    "error": "invalid_smiles",
                }
            )
            continue
        valid_items.append((raw_smiles, standardized))

    ranked_items: list[dict[str, Any]] = []
    if valid_items:
        feature_matrix = model.build_molecule_feature_matrix([item[1] for item in valid_items])
        # A single assay is scored against all candidates: batch dimension of 1.
        candidate_tensor = torch.from_numpy(feature_matrix).unsqueeze(0)
        with torch.no_grad():
            logits, _, _ = model.compatibility_head.score_candidates(
                assay_tensor.to(dtype=torch.float32),
                candidate_tensor.to(dtype=torch.float32),
            )
        scores = logits.squeeze(0).cpu().numpy().tolist()
        for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True):
            ranked_items.append(
                {
                    "input_smiles": raw_smiles,
                    "canonical_smiles": canonical,
                    "smiles_hash": smiles_sha256(canonical),
                    "score": float(score),
                    "valid": True,
                }
            )
        ranked_items.sort(key=lambda item: item["score"], reverse=True)
    if top_k is not None and top_k > 0:
        ranked_items = ranked_items[:top_k]
    return ranked_items + invalid_items