# BioAssayAlign-Qwen3-Embedding-0.6B-Compatibility / bioassayalign_compatibility.py
# Uploaded by lighteternal via huggingface_hub (commit 8a63cbb, verified).
from __future__ import annotations
import hashlib
import json
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors
from rdkit.Chem.MolStandardize import rdMolStandardize
from sentence_transformers import SentenceTransformer
from torch import nn
from transformers.utils import logging as transformers_logging
# Silence third-party progress bars and logging at import time so downstream
# library calls inherit the quiet settings.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
disable_progress_bars()
transformers_logging.set_verbosity_error()
# Suppress all RDKit application-level log output (e.g. sanitization warnings).
RDLogger.DisableLog("rdApp.*")
# Instruction prepended to every assay text before embedding (see _format_assay_query).
DEFAULT_ASSAY_TASK = (
    "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
)
# Canonical ordering of the bracketed sections in serialized assay text.
SECTION_ORDER = [
    "ASSAY_TITLE",
    "DESCRIPTION",
    "ORGANISM",
    "READOUT",
    "ASSAY_FORMAT",
    "ASSAY_TYPE",
    "TARGET_UNIPROT",
]
# Matches a "[SECTION]\n" header line; the capturing group makes re.split keep
# the section name in the split output (consumed by _parse_assay_sections).
ASSAY_SECTION_RE = re.compile(
    r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n"
)
# Maps taxonomy identifiers (keys look like NCBI taxids — TODO confirm) to
# canonical organism tokens used in the hashed metadata features.
ORGANISM_ALIASES = {
    "9606": "homo_sapiens",
    "10090": "mus_musculus",
    "10116": "rattus_norvegicus",
    "4932": "saccharomyces_cerevisiae",
}
# Names (and layout order) of the dense RDKit descriptors computed by
# _molecule_descriptor_vector; the tuple order defines the feature-vector layout.
DEFAULT_DESCRIPTOR_NAMES = (
    "mol_wt",
    "logp",
    "tpsa",
    "heavy_atoms",
    "hbd",
    "hba",
    "rot_bonds",
    "ring_count",
    "aromatic_rings",
    "aliphatic_rings",
    "saturated_rings",
    "fraction_csp3",
    "heteroatoms",
    "amide_bonds",
    "fragments",
    "formal_charge",
    "max_atomic_num",
    "metal_atom_count",
    "halogen_count",
    "nitrogen_count",
    "oxygen_count",
    "sulfur_count",
    "phosphorus_count",
    "fluorine_count",
    "chlorine_count",
    "bromine_count",
    "iodine_count",
    "aromatic_atom_count",
    "spiro_atoms",
    "bridgehead_atoms",
)
# Atomic numbers treated as "organic-like" (H, B, C, N, O, F, Si, P, S, Cl, Br, I);
# atoms outside this set count toward the metal_atom_count descriptor.
ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
@dataclass
class AssayQuery:
    """Structured description of a bioassay; rendered to text by serialize_assay_query."""

    # All fields default to empty so partially specified assays serialize cleanly.
    title: str = ""
    description: str = ""
    organism: str = ""  # free text or a taxonomy ID (aliased via ORGANISM_ALIASES)
    readout: str = ""
    assay_format: str = ""
    assay_type: str = ""
    # UniProt accessions of the assay target(s); joined with ", " on serialization.
    target_uniprot: list[str] | None = None
def serialize_assay_query(query: AssayQuery) -> str:
    """Render *query* as the canonical ``[SECTION]\\n<value>`` text format.

    Sections appear in SECTION_ORDER, separated by blank lines; empty fields
    are still emitted with an empty value.
    """
    target_text = ", ".join(query.target_uniprot or [])
    field_by_section = {
        "ASSAY_TITLE": query.title,
        "DESCRIPTION": query.description,
        "ORGANISM": query.organism,
        "READOUT": query.readout,
        "ASSAY_FORMAT": query.assay_format,
        "ASSAY_TYPE": query.assay_type,
        "TARGET_UNIPROT": target_text,
    }
    chunks: list[str] = []
    for section in SECTION_ORDER:
        chunks.append(f"[{section}]\n{field_by_section[section].strip()}")
    return "\n\n".join(chunks)
def _format_assay_query(assay_text: str, task_description: str) -> str:
return f"Instruct: {task_description.strip()}\nQuery: {assay_text.strip()}"
def _parse_assay_sections(assay_text: str) -> dict[str, str]:
    """Split serialized assay text back into a {section: value} mapping.

    Sections absent from the text map to ""; unknown section names are ignored.
    """
    parsed = dict.fromkeys(SECTION_ORDER, "")
    pieces = ASSAY_SECTION_RE.split(assay_text)
    # re.split with a capturing group yields [prefix, name1, body1, name2, body2, ...].
    for name, body in zip(pieces[1::2], pieces[2::2]):
        if name in parsed:
            parsed[name] = body.strip()
    return parsed
def _normalize_metadata_token(value: str) -> str:
return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_")
def _normalize_organism_token(value: str) -> str:
    """Resolve taxonomy-ID aliases and normalize the organism name to a token."""
    stripped = value.strip()
    if not stripped:
        return ""
    canonical = ORGANISM_ALIASES.get(stripped, stripped)
    return _normalize_metadata_token(canonical)
def _hash_bucket(value: str, dim: int) -> int:
return abs(hash(value)) % max(dim, 1)
def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
    """Build an L2-normalized hashed bag-of-tokens vector from assay metadata.

    Tokens are drawn from the ORGANISM, READOUT, ASSAY_FORMAT, ASSAY_TYPE and
    TARGET_UNIPROT sections; a non-positive *dim* yields an empty vector.
    """
    if dim <= 0:
        return np.zeros((0,), dtype=np.float32)
    parsed = _parse_assay_sections(assay_text)
    feature_tokens: list[str] = []
    organism_token = _normalize_organism_token(parsed.get("ORGANISM", ""))
    if organism_token:
        feature_tokens.append(f"organism:{organism_token}")
    for section in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE"):
        normalized = _normalize_metadata_token(parsed.get(section, ""))
        if normalized:
            feature_tokens.append(f"{section.lower()}:{normalized}")
    for raw_target in parsed.get("TARGET_UNIPROT", "").split(","):
        accession = raw_target.strip().upper()
        if accession:
            feature_tokens.append(f"target:{accession}")
    hashed = np.zeros((dim,), dtype=np.float32)
    for feature_token in feature_tokens:
        hashed[_hash_bucket(feature_token, dim)] += 1.0
    length = float(np.linalg.norm(hashed))
    return hashed / length if length > 0 else hashed
@lru_cache(maxsize=1_000_000)
def _standardize_smiles_v2_cached(smiles: str) -> str | None:
    """Standardize a SMILES string to its canonical parent form, or None.

    Pipeline (order matters): RDKit cleanup -> largest-fragment parent ->
    neutralization -> canonical tautomer -> sanitize.  Returns None when any
    step fails, when fewer than two heavy atoms remain, or when the result is
    still multi-fragment (contains ".").
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    try:
        cleaned = rdMolStandardize.Cleanup(parsed)
        cleaned = rdMolStandardize.FragmentParent(cleaned)
        cleaned = rdMolStandardize.Uncharger().uncharge(cleaned)
        cleaned = rdMolStandardize.TautomerEnumerator().Canonicalize(cleaned)
        Chem.SanitizeMol(cleaned)
    except Exception:
        return None
    if cleaned.GetNumHeavyAtoms() < 2:
        return None
    canonical = Chem.MolToSmiles(cleaned, canonical=True, isomericSmiles=True)
    if not canonical or "." in canonical:
        return None
    return canonical
def standardize_smiles_v2(smiles: str | None) -> str | None:
    """Strip whitespace and delegate to the cached standardizer; None/blank -> None."""
    if smiles:
        stripped = smiles.strip()
        if stripped:
            return _standardize_smiles_v2_cached(stripped)
    return None
def smiles_sha256(smiles: str) -> str:
    """Hex SHA-256 digest of the UTF-8 encoded SMILES (stable molecule identifier)."""
    hasher = hashlib.sha256()
    hasher.update(smiles.encode("utf-8"))
    return hasher.hexdigest()
def _count_atomic_nums(mol) -> dict[int, int]:
counts: dict[int, int] = {}
for atom in mol.GetAtoms():
atomic_num = int(atom.GetAtomicNum())
counts[atomic_num] = counts.get(atomic_num, 0) + 1
return counts
def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
    """Morgan (ECFP-style) bit fingerprint of *mol* as a uint8 vector of length *n_bits*."""
    bitvect = AllChem.GetMorganFingerprintAsBitVect(
        mol, radius, nBits=n_bits, useChirality=use_chirality
    )
    dense = np.zeros((n_bits,), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(bitvect, dense)
    return dense
def _maccs_bits_from_mol(mol) -> np.ndarray:
    """MACCS structural-key fingerprint of *mol* as a uint8 numpy vector."""
    keys = MACCSkeys.GenMACCSKeys(mol)
    dense = np.zeros((keys.GetNumBits(),), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(keys, dense)
    return dense
def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
    """Compute the dense RDKit descriptor vector for *mol*.

    All descriptors in the table below are always computed; *names* selects
    and orders the returned subset.  Returns a float32 array aligned with
    *names*.  Raises KeyError if a name is not in the table.
    """
    counts = _count_atomic_nums(mol)
    # GetMolFrags returns one tuple of atom indices per disconnected component.
    fragments = Chem.GetMolFrags(mol)
    formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
    max_atomic_num = max(counts) if counts else 0
    # "Metal-like" = any element outside the organic set defined by
    # ORGANIC_LIKE_ATOMIC_NUMBERS (H, B, C, N, O, F, Si, P, S, Cl, Br, I).
    metal_atom_count = sum(
        count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS
    )
    halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))  # F, Cl, Br, I
    aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
    values = {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": float(mol.GetNumHeavyAtoms()),
        "hbd": float(Lipinski.NumHDonors(mol)),
        "hba": float(Lipinski.NumHAcceptors(mol)),
        "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
        "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
        "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
        "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
        "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
        "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
        "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
        "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
        "fragments": float(len(fragments)),
        "formal_charge": float(formal_charge),
        "max_atomic_num": float(max_atomic_num),
        "metal_atom_count": float(metal_atom_count),
        "halogen_count": float(halogen_count),
        "nitrogen_count": float(counts.get(7, 0)),
        "oxygen_count": float(counts.get(8, 0)),
        "sulfur_count": float(counts.get(16, 0)),
        "phosphorus_count": float(counts.get(15, 0)),
        "fluorine_count": float(counts.get(9, 0)),
        "chlorine_count": float(counts.get(17, 0)),
        "bromine_count": float(counts.get(35, 0)),
        "iodine_count": float(counts.get(53, 0)),
        "aromatic_atom_count": float(aromatic_atom_count),
        "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
        "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
    }
    return np.asarray([values[name] for name in names], dtype=np.float32)
class CompatibilityHead(nn.Module):
    """Scoring head over projected assay and molecule feature vectors.

    Both sides are layer-normalized, linearly projected into a shared space
    and L2-normalized.  The final logit is a learned-scale cosine term plus
    an MLP over the interaction features ``[a, m, a*m, |a-m|]``.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        # Submodule creation order is kept stable so seeded initialization
        # reproduces the original parameter draws.
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        self.score_mlp = nn.Sequential(
            nn.Linear(projection_dim * 4, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features onto the shared unit sphere."""
        projected = self.assay_proj(self.assay_norm(assay_features))
        return F.normalize(projected, p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features onto the shared unit sphere."""
        projected = self.mol_proj(self.mol_norm(molecule_features))
        return F.normalize(projected, p=2, dim=-1)

    def _combine(self, assay_vec: torch.Tensor, mol_vec: torch.Tensor) -> torch.Tensor:
        """Scaled cosine term plus MLP over the standard interaction features."""
        cosine = (assay_vec * mol_vec).sum(dim=-1)
        interaction = torch.cat(
            [assay_vec, mol_vec, assay_vec * mol_vec, torch.abs(assay_vec - mol_vec)],
            dim=-1,
        )
        mlp_term = self.score_mlp(interaction).squeeze(-1)
        return cosine * self.dot_scale + mlp_term

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every candidate molecule against its assay.

        Returns ``(logits, assay_embeddings, molecule_embeddings)``; the assay
        embedding is broadcast across the candidate axis before combining.
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        broadcast_assay = assay_vec.unsqueeze(1).expand(-1, mol_vec.shape[1], -1)
        logits = self._combine(broadcast_assay, mol_vec)
        return logits, assay_vec, mol_vec

    def score_pairs(self, assay_features: torch.Tensor, molecule_features: torch.Tensor) -> torch.Tensor:
        """Score aligned (assay, molecule) pairs; returns one logit per pair."""
        return self._combine(
            self.encode_assay(assay_features),
            self.encode_molecule(molecule_features),
        )
class BioAssayAlignCompatibilityModel:
    """Inference wrapper pairing a sentence-transformer assay encoder with a CompatibilityHead.

    Holds the molecule featurization spec (fingerprint radii/bits, optional
    MACCS keys, optional standardized RDKit descriptors) and the optional
    hashed assay-metadata feature configuration.
    """

    def __init__(
        self,
        assay_encoder: SentenceTransformer,
        compatibility_head: CompatibilityHead,
        *,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: np.ndarray | None,
        descriptor_std: np.ndarray | None,
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        """Store the encoder, the (eval-mode) head and the featurization spec."""
        self.assay_encoder = assay_encoder
        # Switch the head to eval mode so dropout is inactive during scoring.
        self.compatibility_head = compatibility_head.eval()
        self.assay_task_description = assay_task_description
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim

    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
        """Embed the serialized assay text, optionally appending hashed metadata features."""
        prompt = _format_assay_query(assay_text, self.assay_task_description)
        embeddings = self.assay_encoder.encode(
            [prompt],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        features = embeddings[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            extra = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            features = np.concatenate([features, extra.astype(np.float32)], axis=0)
        return features

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
        """Featurize each SMILES into one row of a float32 matrix.

        Propagates ValueError from the featurizer on unparsable SMILES.
        """
        feature_rows = [
            _smiles_to_molecule_features(
                smiles,
                radii=self.fingerprint_radii,
                n_bits=self.fingerprint_bits,
                use_chirality=self.use_chirality,
                use_maccs=self.use_maccs,
                use_rdkit_descriptors=self.use_rdkit_descriptors,
                descriptor_names=self.descriptor_names,
                descriptor_mean=self.descriptor_mean,
                descriptor_std=self.descriptor_std,
            )
            for smiles in smiles_values
        ]
        return np.stack(feature_rows, axis=0).astype(np.float32)

    def score(self, assay_text: str, smiles: str) -> float:
        """Return the compatibility logit for a single assay/SMILES pair."""
        assay_row = torch.from_numpy(self._build_assay_feature_array(assay_text)).unsqueeze(0)
        molecule_row = torch.from_numpy(self.build_molecule_feature_matrix([smiles])[0]).unsqueeze(0)
        with torch.no_grad():
            logit = self.compatibility_head.score_pairs(assay_row, molecule_row)
        return float(logit.item())
def _load_sentence_transformer(model_name: str):
    """Load the assay embedding model (bf16 when CUDA is available, else fp32).

    When the encoder exposes a tokenizer, padding is forced to the left side.
    """
    preferred_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = SentenceTransformer(
        model_name,
        trust_remote_code=True,
        model_kwargs={"torch_dtype": preferred_dtype},
    )
    tokenizer = getattr(model, "tokenizer", None)
    if tokenizer is not None:
        tokenizer.padding_side = "left"
    return model
def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
if spec:
return spec
radii = tuple(int(item) for item in (cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)]))
return {
"fingerprint_radii": list(radii),
"fingerprint_bits": int(cfg["fingerprint_bits"]),
"use_chirality": bool(cfg.get("use_chirality", False)),
"use_maccs": bool(cfg.get("use_maccs", False)),
"use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
"descriptor_names": [],
"descriptor_mean": None,
"descriptor_std": None,
}
def _smiles_to_molecule_features(
    smiles: str,
    *,
    radii: tuple[int, ...],
    n_bits: int,
    use_chirality: bool,
    use_maccs: bool,
    use_rdkit_descriptors: bool,
    descriptor_names: tuple[str, ...],
    descriptor_mean: np.ndarray | None,
    descriptor_std: np.ndarray | None,
) -> np.ndarray:
    """Build the concatenated feature vector for one SMILES.

    Layout: Morgan bits per radius, optional MACCS bits, then (optionally)
    RDKit descriptors standardized with the checkpoint's mean/std.  Falls back
    to the raw SMILES when standardization rejects it; raises ValueError when
    RDKit cannot parse the molecule at all.
    """
    candidate = standardize_smiles_v2(smiles) or smiles
    mol = Chem.MolFromSmiles(candidate)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {candidate}")
    bit_segments: list[np.ndarray] = []
    for radius in radii:
        bit_segments.append(
            _morgan_bits_from_mol(mol, radius=int(radius), n_bits=n_bits, use_chirality=use_chirality)
        )
    if use_maccs:
        bit_segments.append(_maccs_bits_from_mol(mol))
    segments: list[np.ndarray] = [np.concatenate(bit_segments, axis=0).astype(np.float32)]
    if use_rdkit_descriptors and descriptor_names:
        descriptors = _molecule_descriptor_vector(mol, names=descriptor_names)
        if descriptor_mean is not None and descriptor_std is not None:
            # Apply training-time z-scoring from the checkpoint feature spec.
            descriptors = (descriptors - descriptor_mean) / descriptor_std
        segments.append(descriptors.astype(np.float32))
    return np.concatenate(segments, axis=0).astype(np.float32)
def load_compatibility_model(model_dir: str | Path) -> BioAssayAlignCompatibilityModel:
    """Rebuild the compatibility model from a local checkpoint directory.

    Expects ``best_model.pt`` (torch checkpoint with the head state dict plus
    optional spec/metadata entries) and ``training_metadata.json`` (with a
    ``config`` section).  Raises RuntimeError on a state-dict mismatch.
    """
    model_path = Path(model_dir)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    metadata = json.loads((model_path / "training_metadata.json").read_text())
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)
    # Checkpoint-recorded encoder name takes precedence over the config entry.
    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])
    # Infer input widths from the stored projection weights rather than config.
    assay_dim = int(checkpoint["model_state_dict"]["assay_proj.weight"].shape[1])
    molecule_dim = int(checkpoint["model_state_dict"]["mol_proj.weight"].shape[1])
    head = CompatibilityHead(
        assay_dim=assay_dim,
        molecule_dim=molecule_dim,
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    load_result = head.load_state_dict(checkpoint["model_state_dict"], strict=False)
    # mol_norm params may be absent from the state dict without failing the
    # load — presumably older checkpoints predate mol_norm; TODO confirm.
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            "Compatibility checkpoint load mismatch: "
            f"unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )
    # Featurization settings prefer the explicit feature spec, falling back to
    # config values (and finally hard defaults) for legacy checkpoints.
    return BioAssayAlignCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg["assay_task_description"],
        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", False)),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", False)),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=np.array(feature_spec["descriptor_mean"], dtype=np.float32)
        if feature_spec.get("descriptor_mean") is not None
        else None,
        descriptor_std=np.array(feature_spec["descriptor_std"], dtype=np.float32)
        if feature_spec.get("descriptor_std") is not None
        else None,
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )
def load_compatibility_model_from_hub(repo_id: str) -> BioAssayAlignCompatibilityModel:
    """Download the checkpoint and metadata from the Hub, then build the model."""
    local_dir = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        allow_patterns=["best_model.pt", "training_metadata.json"],
    )
    return load_compatibility_model(local_dir)
def rank_compounds(
    model: BioAssayAlignCompatibilityModel,
    *,
    assay_text: str,
    smiles_list: list[str],
    top_k: int | None = None,
) -> list[dict[str, Any]]:
    """Score and rank candidate SMILES against one assay description.

    Returns score-descending records for valid molecules (optionally truncated
    to *top_k*), followed by one error record per SMILES that failed
    standardization.  An empty input yields an empty list.
    """
    if not smiles_list:
        return []
    assay_vec = model._build_assay_feature_array(assay_text)
    assay_batch = torch.from_numpy(assay_vec.astype(np.float32)).unsqueeze(0)
    standardized_pairs: list[tuple[str, str]] = []
    failures: list[dict[str, Any]] = []
    for candidate in smiles_list:
        canonical = standardize_smiles_v2(candidate)
        if canonical is None:
            failures.append(
                {
                    "input_smiles": candidate,
                    "canonical_smiles": None,
                    "smiles_hash": None,
                    "score": None,
                    "valid": False,
                    "error": "invalid_smiles",
                }
            )
        else:
            standardized_pairs.append((candidate, canonical))
    scored: list[dict[str, Any]] = []
    if standardized_pairs:
        feature_matrix = model.build_molecule_feature_matrix([pair[1] for pair in standardized_pairs])
        candidate_batch = torch.from_numpy(feature_matrix).unsqueeze(0)
        with torch.no_grad():
            logits, _, _ = model.compatibility_head.score_candidates(
                assay_batch.to(dtype=torch.float32),
                candidate_batch.to(dtype=torch.float32),
            )
        score_values = logits.squeeze(0).cpu().numpy().tolist()
        for (candidate, canonical), value in zip(standardized_pairs, score_values, strict=True):
            scored.append(
                {
                    "input_smiles": candidate,
                    "canonical_smiles": canonical,
                    "smiles_hash": smiles_sha256(canonical),
                    "score": float(value),
                    "valid": True,
                }
            )
    scored.sort(key=lambda record: record["score"], reverse=True)
    if top_k is not None and top_k > 0:
        scored = scored[:top_k]
    return scored + failures
def list_softmax_scores(scores: list[float], temperature: float = 1.0) -> list[float]:
    """Convert raw logits to a temperature-scaled softmax distribution.

    *temperature* is clamped to >= 1e-6; smaller values sharpen the
    distribution.  The max-logit shift keeps the exponentials numerically
    stable.  An empty input returns an empty list (previously this raised
    ValueError from reducing an empty array with ``max``).
    """
    if not scores:
        return []
    values = np.asarray(scores, dtype=np.float32) / max(float(temperature), 1e-6)
    values = values - values.max()
    probs = np.exp(values)
    probs = probs / probs.sum()
    return probs.tolist()