lighteternal's picture
Polish UX, examples, and result explainability
f1158c7 verified
from __future__ import annotations
import contextlib
import hashlib
import io
import json
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors
from rdkit.Chem.MolStandardize import rdMolStandardize
from sentence_transformers import SentenceTransformer
from torch import nn
from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging as transformers_logging
# Quiet third-party libraries at import time. setdefault() keeps any value the
# user already exported in the environment.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
# Disable Hugging Face tokenizers' parallelism unless explicitly configured.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# huggingface_hub is already imported above, so the env var set here may be read
# too late; presumably that is why progress bars are also disabled explicitly.
disable_progress_bars()
# Only surface errors from transformers (hides its warnings/info messages).
transformers_logging.set_verbosity_error()
# Turn off RDKit's rdApp.* log channels so invalid user SMILES do not spam the log.
RDLogger.DisableLog("rdApp.*")
# Instruction prefix for the assay encoder; used when the checkpoint has no
# "assay_task_description" and the training config does not define one either.
DEFAULT_ASSAY_TASK = (
    "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
)
# Default descriptor keys for _molecule_descriptor_vector; tuple order is the
# column order of the returned vector.
DEFAULT_DESCRIPTOR_NAMES = (
    "mol_wt",
    "logp",
    "tpsa",
    "heavy_atoms",
    "hbd",
    "hba",
    "rot_bonds",
    "ring_count",
    "aromatic_rings",
    "aliphatic_rings",
    "saturated_rings",
    "fraction_csp3",
    "heteroatoms",
    "amide_bonds",
    "fragments",
    "formal_charge",
    "max_atomic_num",
    "metal_atom_count",
    "halogen_count",
    "nitrogen_count",
    "oxygen_count",
    "sulfur_count",
    "phosphorus_count",
    "fluorine_count",
    "chlorine_count",
    "bromine_count",
    "iodine_count",
    "aromatic_atom_count",
    "spiro_atoms",
    "bridgehead_atoms",
)
# Atomic numbers treated as "organic-like" (H, B, C, N, O, F, Si, P, S, Cl, Br, I).
# Every other element counts toward the "metal_atom_count" descriptor.
# NOTE: elements such as Se (34) or As (33) are therefore counted as "metals" too.
ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
# Section headers written by serialize_assay_query, in output order.
SECTION_ORDER = [
    "ASSAY_TITLE",
    "DESCRIPTION",
    "ORGANISM",
    "READOUT",
    "ASSAY_FORMAT",
    "ASSAY_TYPE",
    "TARGET_UNIPROT",
]
# Matches one "[SECTION]" header line (header followed by a newline) and captures
# the section name. Built from SECTION_ORDER so the header list is defined in
# exactly one place; the resulting pattern is the same literal alternation.
ASSAY_SECTION_RE = re.compile(r"\[(" + "|".join(re.escape(name) for name in SECTION_ORDER) + r")\]\n")
# Numeric organism identifiers (these look like NCBI taxonomy IDs) mapped to
# names before the organism is normalized into a metadata token.
ORGANISM_ALIASES = {
    "9606": "homo_sapiens",
    "10090": "mus_musculus",
    "10116": "rattus_norvegicus",
    "4932": "saccharomyces_cerevisiae",
}
@dataclass
class AssayQuery:
    """Structured assay description; serialize_assay_query renders it as sectioned text.

    Each field fills one "[SECTION]" block of the serialized text, with surrounding
    whitespace stripped.
    """

    title: str = ""  # -> [ASSAY_TITLE]
    description: str = ""  # -> [DESCRIPTION]
    organism: str = ""  # -> [ORGANISM]; IDs listed in ORGANISM_ALIASES are mapped to names for metadata features
    readout: str = ""  # -> [READOUT]
    assay_format: str = ""  # -> [ASSAY_FORMAT]
    assay_type: str = ""  # -> [ASSAY_TYPE]
    target_uniprot: list[str] | None = None  # -> [TARGET_UNIPROT], joined with ", "; None is treated as no targets
def smiles_sha256(smiles: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoded SMILES string."""
    digest = hashlib.sha256()
    digest.update(smiles.encode("utf-8"))
    return digest.hexdigest()
@contextlib.contextmanager
def _silent_imports():
    """Discard everything written to sys.stdout and sys.stderr inside the ``with`` body.

    Used around model downloads/loads to keep third-party chatter out of the UI.
    Captured text is thrown away, including when the body raises.
    """
    sink = io.StringIO()
    with contextlib.redirect_stdout(sink):
        with contextlib.redirect_stderr(sink):
            yield
@lru_cache(maxsize=1_000_000)
def _standardize_smiles_v2_cached(smiles: str) -> str | None:
    """Standardize one (already stripped) SMILES string; memoized per input string.

    Pipeline: parse -> Cleanup -> FragmentParent -> uncharge -> canonical tautomer
    -> sanitize. Returns None when parsing or any standardization step fails, when
    fewer than 2 heavy atoms remain, or when the output is empty or still has
    several fragments ("." in the SMILES).
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    try:
        # Fresh Uncharger / TautomerEnumerator instances per call, as before.
        pipeline = (
            rdMolStandardize.Cleanup,
            rdMolStandardize.FragmentParent,
            lambda m: rdMolStandardize.Uncharger().uncharge(m),
            lambda m: rdMolStandardize.TautomerEnumerator().Canonicalize(m),
        )
        for step in pipeline:
            parsed = step(parsed)
        Chem.SanitizeMol(parsed)
    except Exception:
        # Best effort: any RDKit failure just marks the input as unusable.
        return None
    if parsed.GetNumHeavyAtoms() < 2:
        return None
    canonical = Chem.MolToSmiles(parsed, canonical=True, isomericSmiles=True)
    return canonical if canonical and "." not in canonical else None
def standardize_smiles_v2(smiles: str | None) -> str | None:
    """Return the standardized canonical SMILES for ``smiles``, or None if unusable.

    None, empty, and whitespace-only inputs short-circuit to None; everything else is
    stripped and passed to the memoized standardizer.
    """
    token = (smiles or "").strip()
    return _standardize_smiles_v2_cached(token) if token else None
def serialize_assay_query(query: AssayQuery) -> str:
    """Render ``query`` as bracketed section blocks in SECTION_ORDER.

    Each block is the "[NAME]" header line followed by the stripped field value;
    blocks are separated by a blank line. Empty fields still emit their header.
    Target accessions are joined with ", ".
    """
    field_values = (
        query.title,
        query.description,
        query.organism,
        query.readout,
        query.assay_format,
        query.assay_type,
        ", ".join(query.target_uniprot or []),
    )
    field_names = (
        "ASSAY_TITLE",
        "DESCRIPTION",
        "ORGANISM",
        "READOUT",
        "ASSAY_FORMAT",
        "ASSAY_TYPE",
        "TARGET_UNIPROT",
    )
    cleaned = {name: text.strip() for name, text in zip(field_names, field_values)}
    return "\n\n".join("[" + name + "]\n" + cleaned[name] for name in SECTION_ORDER)
def _parse_assay_sections(assay_text: str) -> dict[str, str]:
    """Split serialized assay text back into {section name: stripped body}.

    Every name in SECTION_ORDER is present in the result (missing sections map to "").
    Text before the first header is ignored; a repeated header keeps its last body.
    """
    found = dict.fromkeys(SECTION_ORDER, "")
    pieces = ASSAY_SECTION_RE.split(assay_text)
    # With one capture group, split() yields [prefix, name1, body1, name2, body2, ...].
    for header, body in zip(pieces[1::2], pieces[2::2]):
        if header in found:
            found[header] = body.strip()
    return found
def _hash_bucket(value: str, dim: int) -> int:
    """Map ``value`` to a stable bucket index in ``[0, max(dim, 1))``.

    Uses the first 8 bytes of SHA-256 over the UTF-8 encoding, so the same token
    lands in the same bucket in every process and on every machine. The previous
    implementation used the built-in ``hash()``, which is salted per interpreter
    process for ``str`` (unless PYTHONHASHSEED is fixed). Identical assay metadata
    therefore produced different feature vectors, and different scores, after
    every restart.

    NOTE(review): if the checkpoint was trained with the salted ``hash()``, no
    deterministic mapping can reproduce its training-time buckets; the metadata
    features only become meaningful again after retraining with this function.
    """
    digest = hashlib.sha256(value.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % max(dim, 1)
def _normalize_metadata_token(value: str) -> str:
    """Lower-case ``value`` and join its ASCII alphanumeric runs with single underscores.

    Any run of other characters acts as a separator; leading/trailing separators vanish.
    """
    pieces = re.split(r"[^a-z0-9]+", value.lower())
    return "_".join(piece for piece in pieces if piece)
def _normalize_organism_token(value: str) -> str:
    """Normalize an organism field, first replacing known IDs via ORGANISM_ALIASES.

    Blank input yields "".
    """
    stripped = value.strip()
    if stripped:
        return _normalize_metadata_token(ORGANISM_ALIASES.get(stripped, stripped))
    return ""
def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
    """Build an L2-normalized, feature-hashed bag of assay metadata tokens.

    Tokens come from the ORGANISM, READOUT, ASSAY_FORMAT, ASSAY_TYPE and
    TARGET_UNIPROT sections of serialized assay text. Each token adds 1.0 to its
    hash bucket. Returns a float32 vector of length ``dim``: all zeros when no
    token is found, and an empty vector when ``dim <= 0``.
    """
    if dim <= 0:
        return np.zeros((0,), dtype=np.float32)
    parsed = _parse_assay_sections(assay_text)
    features: list[str] = []
    organism_token = _normalize_organism_token(parsed.get("ORGANISM", ""))
    if organism_token:
        features.append("organism:" + organism_token)
    features.extend(
        f"{field.lower()}:{normalized}"
        for field in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE")
        if (normalized := _normalize_metadata_token(parsed.get(field, "")))
    )
    accessions = (piece.strip().upper() for piece in parsed.get("TARGET_UNIPROT", "").split(","))
    features.extend("target:" + accession for accession in accessions if accession)
    counts = np.zeros(dim, dtype=np.float32)
    for feature in features:
        counts[_hash_bucket(feature, dim)] += 1.0
    magnitude = float(np.linalg.norm(counts))
    if magnitude > 0:
        counts /= magnitude
    return counts
def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
    """Return the Morgan bit fingerprint of ``mol`` as a dense uint8 0/1 vector of length ``n_bits``."""
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(
        mol,
        radius,
        nBits=n_bits,
        useChirality=use_chirality,
    )
    dense = np.zeros(n_bits, dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(bit_vector, dense)
    return dense
def _maccs_bits_from_mol(mol) -> np.ndarray:
    """Return the RDKit MACCS keys of ``mol`` as a dense uint8 0/1 vector."""
    keys = MACCSkeys.GenMACCSKeys(mol)
    dense = np.zeros(keys.GetNumBits(), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(keys, dense)
    return dense
def _count_atomic_nums(mol) -> dict[int, int]:
    """Count the atoms of ``mol`` by atomic number, keyed in first-seen order."""
    tally: dict[int, int] = {}
    for atom in mol.GetAtoms():
        key = int(atom.GetAtomicNum())
        if key in tally:
            tally[key] += 1
        else:
            tally[key] = 1
    return tally
def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
    """Compute whole-molecule RDKit descriptors and per-element atom counts for ``mol``.

    Args:
        mol: RDKit molecule.
        names: Descriptor keys to emit, in output order. All descriptors are computed
            regardless of ``names``; an unknown key raises KeyError.

    Returns:
        float32 vector with one entry per key in ``names``.
    """
    element_counts = _count_atomic_nums(mol)
    atom_list = list(mol.GetAtoms())

    def element(atomic_num: int) -> float:
        return float(element_counts.get(atomic_num, 0))

    # "Metal" means any element outside ORGANIC_LIKE_ATOMIC_NUMBERS.
    non_organic = sum(n for z, n in element_counts.items() if z not in ORGANIC_LIKE_ATOMIC_NUMBERS)
    halogens = sum(element_counts.get(z, 0) for z in (9, 17, 35, 53))
    net_charge = sum(int(atom.GetFormalCharge()) for atom in atom_list)
    aromatic_atoms = sum(1 for atom in atom_list if atom.GetIsAromatic())

    table: dict[str, float] = {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": float(mol.GetNumHeavyAtoms()),
        "hbd": float(Lipinski.NumHDonors(mol)),
        "hba": float(Lipinski.NumHAcceptors(mol)),
        "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
        "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
        "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
        "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
        "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
        "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
        "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
        "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
        "fragments": float(len(Chem.GetMolFrags(mol))),
        "formal_charge": float(net_charge),
        "max_atomic_num": float(max(element_counts, default=0)),
        "metal_atom_count": float(non_organic),
        "halogen_count": float(halogens),
        "nitrogen_count": element(7),
        "oxygen_count": element(8),
        "sulfur_count": element(16),
        "phosphorus_count": element(15),
        "fluorine_count": element(9),
        "chlorine_count": element(17),
        "bromine_count": element(35),
        "iodine_count": element(53),
        "aromatic_atom_count": float(aromatic_atoms),
        "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
        "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
    }
    return np.array([table[key] for key in names], dtype=np.float32)
def molecule_ui_metrics(smiles: str) -> dict[str, float | int]:
    """Return display metrics (MW, cLogP, TPSA, heavy-atom count) for a SMILES string.

    The standardized form is used when standardization succeeds, otherwise the raw
    input. If neither parses, every metric is zero (heavy_atoms as int 0).
    """
    mol = Chem.MolFromSmiles(standardize_smiles_v2(smiles) or smiles)
    if mol is not None:
        return {
            "mol_wt": float(Descriptors.MolWt(mol)),
            "logp": float(Crippen.MolLogP(mol)),
            "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
            "heavy_atoms": int(mol.GetNumHeavyAtoms()),
        }
    return {"mol_wt": 0.0, "logp": 0.0, "tpsa": 0.0, "heavy_atoms": 0}
class CompatibilityHead(nn.Module):
    """Score (assay, molecule) pairs from precomputed feature vectors.

    Each side is LayerNorm-ed, linearly projected to ``projection_dim`` and
    L2-normalized. A pair's logit is the learnable-scaled cosine similarity plus an
    MLP over ``[assay, molecule, assay * molecule, |assay - molecule|]``.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        # Attribute names and the Sequential layout determine the state_dict keys
        # (e.g. "assay_proj.weight", "mol_proj.weight") that saved checkpoints use.
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        pair_dim = 4 * projection_dim
        self.score_mlp = nn.Sequential(
            nn.Linear(pair_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features to unit-norm vectors along the last dimension."""
        normed = self.assay_norm(assay_features)
        return F.normalize(self.assay_proj(normed), p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features to unit-norm vectors along the last dimension."""
        normed = self.mol_norm(molecule_features)
        return F.normalize(self.mol_proj(normed), p=2, dim=-1)

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every candidate molecule against its batch's assay.

        Expects ``assay_features`` as (batch, assay_dim) and ``candidate_features`` as
        (batch, n_candidates, molecule_dim).

        Returns:
            (logits of shape (batch, n_candidates), assay vectors, molecule vectors).
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        n_candidates = mol_vec.shape[1]
        assay_tiled = assay_vec[:, None].expand(-1, n_candidates, -1)
        product = assay_tiled * mol_vec
        distance = (assay_tiled - mol_vec).abs()
        dot_scores = product.sum(dim=-1)
        pair_features = torch.cat((assay_tiled, mol_vec, product, distance), dim=-1)
        mlp_scores = self.score_mlp(pair_features).squeeze(-1)
        return dot_scores * self.dot_scale + mlp_scores, assay_vec, mol_vec
class SpaceCompatibilityModel:
    """Inference wrapper that turns assay text and SMILES into CompatibilityHead inputs.

    Assay features are the sentence-transformer embedding of the instruction-formatted
    assay text, optionally followed by a hashed metadata vector. Each molecule row
    contains, in order:

    * Morgan bits for every radius in ``fingerprint_radii``;
    * optional MACCS keys;
    * optional (standardized) RDKit descriptors;
    * an optional mean-pooled embedding from a Hugging Face molecule transformer,
      loaded lazily on first use.
    """

    def __init__(
        self,
        *,
        assay_encoder: SentenceTransformer,
        compatibility_head: CompatibilityHead,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: np.ndarray | None,
        descriptor_std: np.ndarray | None,
        molecule_transformer_model_name: str,
        molecule_transformer_batch_size: int,
        molecule_transformer_max_length: int,
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        self.assay_encoder = assay_encoder
        self.compatibility_head = compatibility_head.eval()
        self.assay_task_description = assay_task_description
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        self.molecule_transformer_model_name = molecule_transformer_model_name
        self.molecule_transformer_batch_size = molecule_transformer_batch_size
        self.molecule_transformer_max_length = molecule_transformer_max_length
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim
        # Loaded lazily by _ensure_molecule_transformer_loaded().
        self._molecule_transformer_tokenizer = None
        self._molecule_transformer_model = None
        self._molecule_transformer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _format_assay_query(self, assay_text: str) -> str:
        """Wrap assay text in the instruction prompt: an "Instruct:" line, then a "Query:" line."""
        return f"Instruct: {self.assay_task_description.strip()}\nQuery: {assay_text.strip()}"

    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
        """Return the 1-D float32 assay feature vector expected by the head's assay side."""
        assay_features = self.assay_encoder.encode(
            [self._format_assay_query(assay_text)],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
        return assay_features

    def _ensure_molecule_transformer_loaded(self) -> None:
        """Load the molecule transformer and tokenizer once (fp16 on CUDA, fp32 otherwise)."""
        if not self.molecule_transformer_model_name or self._molecule_transformer_model is not None:
            return
        dtype = torch.float16 if self._molecule_transformer_device.type == "cuda" else torch.float32
        with _silent_imports():
            self._molecule_transformer_tokenizer = AutoTokenizer.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
            )
            self._molecule_transformer_model = AutoModel.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
                torch_dtype=dtype,
            ).to(self._molecule_transformer_device)
        self._molecule_transformer_model.eval()

    def _encode_molecule_transformer_batch(self, smiles_values: list[str]) -> np.ndarray | None:
        """Mean-pool the molecule transformer's last hidden state for each SMILES.

        Returns None when no molecule transformer is configured. Otherwise returns a
        float32 array with one row per input, in input order; padding tokens are
        excluded from the mean via the attention mask.

        Raises:
            RuntimeError: if the model or tokenizer is still missing after loading.
        """
        if not self.molecule_transformer_model_name:
            return None
        self._ensure_molecule_transformer_loaded()
        tokenizer = self._molecule_transformer_tokenizer
        model = self._molecule_transformer_model
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if model is None or tokenizer is None:
            raise RuntimeError(f"Molecule transformer {self.molecule_transformer_model_name!r} failed to load")
        outputs: list[np.ndarray] = []
        batch_size = max(self.molecule_transformer_batch_size, 1)
        with torch.no_grad():
            for start in range(0, len(smiles_values), batch_size):
                batch = smiles_values[start : start + batch_size]
                encoded = tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.molecule_transformer_max_length,
                    return_tensors="pt",
                )
                encoded = {key: value.to(self._molecule_transformer_device) for key, value in encoded.items()}
                hidden = model(**encoded).last_hidden_state
                mask = encoded["attention_mask"].unsqueeze(-1)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                outputs.append(pooled.detach().cpu().to(torch.float32).numpy())
        return np.concatenate(outputs, axis=0).astype(np.float32)

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
        """Build the float32 molecule feature matrix, one row per SMILES.

        Fingerprints and descriptors use the standardized SMILES when standardization
        succeeds; the transformer embedding uses ``smiles_values`` exactly as given.

        Returns:
            Array of shape (len(smiles_values), molecule_dim). An empty input yields
            a (0, molecule_dim) array, with molecule_dim read from the head's
            ``mol_proj`` layer.

        Raises:
            ValueError: if a SMILES cannot be parsed. This is checked for every input
                before the molecule transformer is loaded or run.
        """
        if not smiles_values:
            molecule_dim = int(self.compatibility_head.mol_proj.in_features)
            return np.zeros((0, molecule_dim), dtype=np.float32)

        # Parse everything first so a bad SMILES fails fast, before any transformer work.
        mols = []
        for smiles in smiles_values:
            normalized = standardize_smiles_v2(smiles) or smiles
            mol = Chem.MolFromSmiles(normalized)
            if mol is None:
                raise ValueError(f"Could not parse SMILES: {normalized}")
            mols.append(mol)

        transformer_matrix = self._encode_molecule_transformer_batch(smiles_values)

        use_descriptors = self.use_rdkit_descriptors and bool(self.descriptor_names)
        safe_std = None
        if use_descriptors and self.descriptor_mean is not None and self.descriptor_std is not None:
            # A zero (or NaN) std would turn that column into inf/NaN and poison every
            # score downstream; divide by 1 for such columns instead.
            safe_std = np.where(self.descriptor_std > 0, self.descriptor_std, 1.0).astype(np.float32)

        rows: list[np.ndarray] = []
        for idx, mol in enumerate(mols):
            bit_blocks: list[np.ndarray] = [
                _morgan_bits_from_mol(mol, radius=int(radius), n_bits=self.fingerprint_bits, use_chirality=self.use_chirality)
                for radius in self.fingerprint_radii
            ]
            if self.use_maccs:
                bit_blocks.append(_maccs_bits_from_mol(mol))
            output_blocks: list[np.ndarray] = [np.concatenate(bit_blocks, axis=0).astype(np.float32)]
            if use_descriptors:
                dense = _molecule_descriptor_vector(mol, names=self.descriptor_names)
                if safe_std is not None:
                    dense = (dense - self.descriptor_mean) / safe_std
                output_blocks.append(dense.astype(np.float32))
            if transformer_matrix is not None:
                output_blocks.append(np.asarray(transformer_matrix[idx], dtype=np.float32))
            rows.append(np.concatenate(output_blocks, axis=0).astype(np.float32))
        return np.stack(rows, axis=0)
def _load_sentence_transformer(model_name: str) -> SentenceTransformer:
    """Load the assay SentenceTransformer with remote code enabled and left padding.

    Weights load as bfloat16 when CUDA is available and float32 otherwise; library
    output during loading is suppressed.
    """
    if torch.cuda.is_available():
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32
    with _silent_imports():
        encoder = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"torch_dtype": weight_dtype})
    if getattr(encoder, "tokenizer", None) is None:
        return encoder
    # Left padding presumably matches a last-token-pooling encoder; confirm with the model card.
    encoder.tokenizer.padding_side = "left"
    return encoder
def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Return the molecule feature spec used to train the checkpoint.

    Lookup order:

    1. ``checkpoint["molecule_feature_spec"]``;
    2. ``metadata["molecule_feature_spec"]``;
    3. a legacy spec rebuilt from the training config ``cfg``.

    The legacy spec carries no descriptor names or normalization statistics.
    ``fingerprint_bits`` falls back to 2048 when absent from ``cfg``, matching the
    default applied in load_compatibility_model; previously a missing key raised
    KeyError here.
    """
    spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
    if spec:
        return spec
    raw_radii = cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)]
    return {
        "fingerprint_radii": [int(item) for item in raw_radii],
        "fingerprint_bits": int(cfg.get("fingerprint_bits", 2048)),
        "use_chirality": bool(cfg.get("use_chirality", False)),
        "use_maccs": bool(cfg.get("use_maccs", False)),
        "use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
        "descriptor_names": [],
        "descriptor_mean": None,
        "descriptor_std": None,
        "molecule_transformer_model_name": str(cfg.get("molecule_transformer_model_name") or ""),
        "molecule_transformer_max_length": int(cfg.get("molecule_transformer_max_length", 128) or 128),
    }
def load_compatibility_model(model_dir: str | Path) -> SpaceCompatibilityModel:
    """Build a SpaceCompatibilityModel from a local model directory.

    Reads ``best_model.pt`` (a torch checkpoint containing ``model_state_dict``) and
    ``training_metadata.json`` (which must contain a ``config`` object) from
    ``model_dir``. Checkpoint values take precedence over the feature spec, which
    takes precedence over the training config.

    Raises:
        RuntimeError: if the checkpoint's state dict does not fit CompatibilityHead.
            The one tolerated gap is missing ``mol_norm`` weights (presumably older
            checkpoints); that LayerNorm then keeps its default initialization.
    """
    model_path = Path(model_dir)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,
    # which can execute code. Only load checkpoints from sources you trust.
    # TODO(review): check whether this checkpoint also loads with weights_only=True.
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    metadata = json.loads((model_path / "training_metadata.json").read_text(encoding="utf-8"))
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)
    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])
    # Input widths are read from the saved projection weights, so they always match the checkpoint.
    assay_dim = int(checkpoint["model_state_dict"]["assay_proj.weight"].shape[1])
    molecule_dim = int(checkpoint["model_state_dict"]["mol_proj.weight"].shape[1])
    head = CompatibilityHead(
        assay_dim=assay_dim,
        molecule_dim=molecule_dim,
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    load_result = head.load_state_dict(checkpoint["model_state_dict"], strict=False)
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            f"Checkpoint mismatch: unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )
    return SpaceCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg.get("assay_task_description", DEFAULT_ASSAY_TASK),
        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", cfg.get("use_maccs", False))),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", cfg.get("use_rdkit_descriptors", False))),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=np.array(feature_spec["descriptor_mean"], dtype=np.float32) if feature_spec.get("descriptor_mean") is not None else None,
        descriptor_std=np.array(feature_spec["descriptor_std"], dtype=np.float32) if feature_spec.get("descriptor_std") is not None else None,
        molecule_transformer_model_name=str(feature_spec.get("molecule_transformer_model_name") or cfg.get("molecule_transformer_model_name") or ""),
        molecule_transformer_batch_size=int(cfg.get("molecule_transformer_batch_size", 128) or 128),
        molecule_transformer_max_length=int(feature_spec.get("molecule_transformer_max_length") or cfg.get("molecule_transformer_max_length", 128) or 128),
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )
@lru_cache(maxsize=1)
def load_compatibility_model_from_hub(model_repo_id: str) -> SpaceCompatibilityModel:
    """Fetch a model repo snapshot with snapshot_download and load it.

    Only the checkpoint, the training metadata and the README are requested.
    The loaded model is memoized for the most recently used ``model_repo_id`` only.
    """
    needed_files = ["best_model.pt", "training_metadata.json", "README.md"]
    with _silent_imports():
        local_path = snapshot_download(
            repo_id=model_repo_id,
            repo_type="model",
            allow_patterns=needed_files,
        )
    return load_compatibility_model(local_path)
def rank_compounds(
    model: SpaceCompatibilityModel,
    *,
    assay_text: str,
    smiles_list: list[str],
    top_k: int | None = None,
) -> list[dict[str, Any]]:
    """Score SMILES strings for compatibility with an assay and rank them.

    Args:
        model: Loaded compatibility model.
        assay_text: Assay text, e.g. the output of serialize_assay_query.
        smiles_list: Raw SMILES strings; each is standardized with standardize_smiles_v2.
        top_k: If a positive int, keep only the top_k highest-scoring valid molecules.
            None, 0 or negative values keep all of them.

    Returns:
        Valid molecules sorted by descending score (keys: input_smiles,
        canonical_smiles, smiles_hash, score, valid=True), followed by every input
        that failed standardization, in input order (canonical_smiles/smiles_hash/
        score None, valid=False, error="invalid_smiles"). Invalid entries are never
        truncated by top_k.
    """
    if not smiles_list:
        return []
    valid_items: list[tuple[str, str]] = []
    invalid_items: list[dict[str, Any]] = []
    for raw_smiles in smiles_list:
        canonical = standardize_smiles_v2(raw_smiles)
        if canonical is None:
            invalid_items.append(
                {
                    "input_smiles": raw_smiles,
                    "canonical_smiles": None,
                    "smiles_hash": None,
                    "score": None,
                    "valid": False,
                    "error": "invalid_smiles",
                }
            )
        else:
            valid_items.append((raw_smiles, canonical))
    if not valid_items:
        # Nothing to score: skip the sentence-transformer forward pass for the assay.
        return invalid_items

    assay_features = model._build_assay_feature_array(assay_text)
    assay_tensor = torch.from_numpy(assay_features.astype(np.float32)).unsqueeze(0)
    feature_matrix = model.build_molecule_feature_matrix([canonical for _, canonical in valid_items])
    candidate_tensor = torch.from_numpy(feature_matrix).unsqueeze(0)
    with torch.no_grad():
        logits, _, _ = model.compatibility_head.score_candidates(
            assay_tensor.to(dtype=torch.float32),
            candidate_tensor.to(dtype=torch.float32),
        )
    # squeeze(0) drops only the batch axis, so a single candidate still gives a list.
    scores = logits.squeeze(0).cpu().numpy().tolist()
    ranked_items: list[dict[str, Any]] = [
        {
            "input_smiles": raw_smiles,
            "canonical_smiles": canonical,
            "smiles_hash": smiles_sha256(canonical),
            "score": float(score),
            "valid": True,
        }
        for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True)
    ]
    ranked_items.sort(key=lambda item: item["score"], reverse=True)
    if top_k is not None and top_k > 0:
        ranked_items = ranked_items[:top_k]
    return ranked_items + invalid_items