Upload bioassayalign_compatibility.py with huggingface_hub
Browse files- bioassayalign_compatibility.py +283 -148
bioassayalign_compatibility.py
CHANGED
|
@@ -39,7 +39,9 @@ SECTION_ORDER = [
|
|
| 39 |
"ASSAY_TYPE",
|
| 40 |
"TARGET_UNIPROT",
|
| 41 |
]
|
| 42 |
-
ASSAY_SECTION_RE = re.compile(
|
|
|
|
|
|
|
| 43 |
ORGANISM_ALIASES = {
|
| 44 |
"9606": "homo_sapiens",
|
| 45 |
"10090": "mus_musculus",
|
|
@@ -106,6 +108,10 @@ def serialize_assay_query(query: AssayQuery) -> str:
|
|
| 106 |
return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)
|
| 107 |
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
def _parse_assay_sections(assay_text: str) -> dict[str, str]:
|
| 110 |
sections = {key: "" for key in SECTION_ORDER}
|
| 111 |
parts = ASSAY_SECTION_RE.split(assay_text)
|
|
@@ -219,7 +225,9 @@ def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIP
|
|
| 219 |
fragments = Chem.GetMolFrags(mol)
|
| 220 |
formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
|
| 221 |
max_atomic_num = max(counts) if counts else 0
|
| 222 |
-
metal_atom_count = sum(
|
|
|
|
|
|
|
| 223 |
halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
|
| 224 |
aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
|
| 225 |
values = {
|
|
@@ -258,177 +266,304 @@ def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIP
|
|
| 258 |
|
| 259 |
|
| 260 |
class CompatibilityHead(nn.Module):
|
| 261 |
-
def __init__(
|
| 262 |
-
self,
|
| 263 |
-
*,
|
| 264 |
-
assay_dim: int,
|
| 265 |
-
molecule_dim: int,
|
| 266 |
-
projection_dim: int,
|
| 267 |
-
hidden_dim: int,
|
| 268 |
-
dropout: float,
|
| 269 |
-
metadata_dim: int = 0,
|
| 270 |
-
) -> None:
|
| 271 |
super().__init__()
|
| 272 |
-
self.
|
| 273 |
-
|
| 274 |
-
self.
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
nn.Dropout(dropout),
|
| 278 |
-
)
|
| 279 |
-
self.molecule_proj = nn.Sequential(
|
| 280 |
-
nn.Linear(molecule_dim, projection_dim),
|
| 281 |
-
nn.GELU(),
|
| 282 |
-
nn.Dropout(dropout),
|
| 283 |
-
)
|
| 284 |
-
self.scorer = nn.Sequential(
|
| 285 |
nn.Linear(projection_dim * 4, hidden_dim),
|
| 286 |
nn.GELU(),
|
| 287 |
nn.Dropout(dropout),
|
| 288 |
nn.Linear(hidden_dim, 1),
|
| 289 |
)
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
[
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
torch.abs(
|
| 304 |
],
|
| 305 |
dim=-1,
|
| 306 |
)
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
|
| 310 |
-
class
|
| 311 |
-
def __init__(
|
| 312 |
-
self
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
self.assay_encoder = assay_encoder
|
| 318 |
-
self.
|
| 319 |
-
self.
|
| 320 |
-
self.
|
| 321 |
-
self.
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
self.
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
[
|
| 335 |
-
|
| 336 |
-
show_progress_bar=False,
|
| 337 |
normalize_embeddings=True,
|
| 338 |
-
|
| 339 |
-
|
| 340 |
)[0].astype(np.float32)
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
if
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
|
| 390 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
model_path = Path(model_dir)
|
| 392 |
-
training_metadata = json.loads((model_path / "training_metadata.json").read_text())
|
| 393 |
checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
continue
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
{
|
| 421 |
-
"input_smiles":
|
| 422 |
-
"canonical_smiles":
|
| 423 |
-
"smiles_hash": smiles_sha256(
|
| 424 |
"score": float(score),
|
| 425 |
"valid": True,
|
| 426 |
}
|
| 427 |
)
|
| 428 |
-
|
| 429 |
-
if top_k:
|
| 430 |
-
|
| 431 |
-
|
|
|
|
| 432 |
|
| 433 |
|
| 434 |
def list_softmax_scores(scores: list[float], temperature: float = 1.0) -> list[float]:
|
|
|
|
| 39 |
"ASSAY_TYPE",
|
| 40 |
"TARGET_UNIPROT",
|
| 41 |
]
|
| 42 |
+
# Matches a section header line such as "[ASSAY_TITLE]\n".  The section name is
# captured so that re.split() interleaves header names with section bodies.
ASSAY_SECTION_RE = re.compile(
    r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n"
)
|
| 45 |
ORGANISM_ALIASES = {
|
| 46 |
"9606": "homo_sapiens",
|
| 47 |
"10090": "mus_musculus",
|
|
|
|
| 108 |
return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)
|
| 109 |
|
| 110 |
|
| 111 |
+
def _format_assay_query(assay_text: str, task_description: str) -> str:
|
| 112 |
+
return f"Instruct: {task_description.strip()}\nQuery: {assay_text.strip()}"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
def _parse_assay_sections(assay_text: str) -> dict[str, str]:
|
| 116 |
sections = {key: "" for key in SECTION_ORDER}
|
| 117 |
parts = ASSAY_SECTION_RE.split(assay_text)
|
|
|
|
| 225 |
fragments = Chem.GetMolFrags(mol)
|
| 226 |
formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
|
| 227 |
max_atomic_num = max(counts) if counts else 0
|
| 228 |
+
metal_atom_count = sum(
|
| 229 |
+
count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS
|
| 230 |
+
)
|
| 231 |
halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
|
| 232 |
aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
|
| 233 |
values = {
|
|
|
|
| 266 |
|
| 267 |
|
| 268 |
class CompatibilityHead(nn.Module):
    """Scores assay/molecule compatibility in a shared projection space.

    Both inputs are layer-normalized, projected, and L2-normalized onto the
    unit sphere; the final logit blends a learnable-scaled dot product with a
    small MLP over concatenated interaction features.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        self.score_mlp = nn.Sequential(
            nn.Linear(projection_dim * 4, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        # Learnable weight balancing the dot-product term against the MLP term.
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features onto the unit sphere of the shared space."""
        projected = self.assay_proj(self.assay_norm(assay_features))
        return F.normalize(projected, p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features onto the unit sphere of the shared space."""
        projected = self.mol_proj(self.mol_norm(molecule_features))
        return F.normalize(projected, p=2, dim=-1)

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score each assay against its own set of candidate molecules.

        ``candidate_features`` carries one candidate set per assay (an extra
        middle dimension indexed by candidate).  Returns ``(logits, assay_vec,
        mol_vec)`` where ``logits`` has one score per (assay, candidate).
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        # Broadcast each assay vector across its candidate axis.
        assay_expand = assay_vec.unsqueeze(1).expand(-1, mol_vec.shape[1], -1)
        dot_scores = (assay_expand * mol_vec).sum(dim=-1)
        interaction = torch.cat(
            [
                assay_expand,
                mol_vec,
                assay_expand * mol_vec,
                torch.abs(assay_expand - mol_vec),
            ],
            dim=-1,
        )
        mlp_scores = self.score_mlp(interaction).squeeze(-1)
        logits = dot_scores * self.dot_scale + mlp_scores
        return logits, assay_vec, mol_vec

    def score_pairs(self, assay_features: torch.Tensor, molecule_features: torch.Tensor) -> torch.Tensor:
        """Score aligned (assay, molecule) pairs; returns one logit per pair."""
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(molecule_features)
        dot_scores = (assay_vec * mol_vec).sum(dim=-1)
        interaction = torch.cat(
            [assay_vec, mol_vec, assay_vec * mol_vec, torch.abs(assay_vec - mol_vec)],
            dim=-1,
        )
        mlp_scores = self.score_mlp(interaction).squeeze(-1)
        return dot_scores * self.dot_scale + mlp_scores
|
| 319 |
|
| 320 |
|
| 321 |
+
class BioAssayAlignCompatibilityModel:
    """Inference wrapper pairing an assay text encoder with a compatibility head."""

    def __init__(
        self,
        assay_encoder: "SentenceTransformer",
        compatibility_head: "CompatibilityHead",
        *,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: "np.ndarray | None",
        descriptor_std: "np.ndarray | None",
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        self.assay_encoder = assay_encoder
        # Put the head into eval mode once so dropout is inactive at inference.
        self.compatibility_head = compatibility_head.eval()
        self.assay_task_description = assay_task_description
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim

    def _build_assay_feature_array(self, assay_text: str) -> "np.ndarray":
        """Encode assay text (plus optional metadata features) to a 1-D float32 array."""
        query = _format_assay_query(assay_text, self.assay_task_description)
        assay_features = self.assay_encoder.encode(
            [query],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
        return assay_features

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> "np.ndarray":
        """Featurize SMILES strings into a float32 matrix with one row per input."""
        rows = [
            _smiles_to_molecule_features(
                smiles,
                radii=self.fingerprint_radii,
                n_bits=self.fingerprint_bits,
                use_chirality=self.use_chirality,
                use_maccs=self.use_maccs,
                use_rdkit_descriptors=self.use_rdkit_descriptors,
                descriptor_names=self.descriptor_names,
                descriptor_mean=self.descriptor_mean,
                descriptor_std=self.descriptor_std,
            )
            for smiles in smiles_values
        ]
        return np.stack(rows, axis=0).astype(np.float32)

    def score(self, assay_text: str, smiles: str) -> float:
        """Return the compatibility score for a single (assay, SMILES) pair."""
        assay_features = self._build_assay_feature_array(assay_text)
        molecule_features = self.build_molecule_feature_matrix([smiles])[0]
        assay_tensor = torch.from_numpy(assay_features).unsqueeze(0)
        molecule_tensor = torch.from_numpy(molecule_features).unsqueeze(0)
        with torch.no_grad():
            score = self.compatibility_head.score_pairs(assay_tensor, molecule_tensor)
        return float(score.item())
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _load_sentence_transformer(model_name: str):
    """Load the assay text encoder, preferring bfloat16 on CUDA hosts."""
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    encoder = SentenceTransformer(
        model_name,
        trust_remote_code=True,
        model_kwargs={"torch_dtype": dtype},
    )
    # NOTE(review): left padding is forced when a tokenizer is present —
    # presumably required by the underlying (decoder-style) encoder; confirm.
    tokenizer = getattr(encoder, "tokenizer", None)
    if tokenizer is not None:
        tokenizer.padding_side = "left"
    return encoder
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
|
| 408 |
+
spec = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
|
| 409 |
+
if spec:
|
| 410 |
+
return spec
|
| 411 |
+
radii = tuple(int(item) for item in (cfg.get("fingerprint_radii") or [cfg.get("fingerprint_radius", 2)]))
|
| 412 |
+
return {
|
| 413 |
+
"fingerprint_radii": list(radii),
|
| 414 |
+
"fingerprint_bits": int(cfg["fingerprint_bits"]),
|
| 415 |
+
"use_chirality": bool(cfg.get("use_chirality", False)),
|
| 416 |
+
"use_maccs": bool(cfg.get("use_maccs", False)),
|
| 417 |
+
"use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
|
| 418 |
+
"descriptor_names": [],
|
| 419 |
+
"descriptor_mean": None,
|
| 420 |
+
"descriptor_std": None,
|
| 421 |
+
}
|
| 422 |
|
| 423 |
|
| 424 |
+
def _smiles_to_molecule_features(
    smiles: str,
    *,
    radii: tuple[int, ...],
    n_bits: int,
    use_chirality: bool,
    use_maccs: bool,
    use_rdkit_descriptors: bool,
    descriptor_names: tuple[str, ...],
    descriptor_mean: "np.ndarray | None",
    descriptor_std: "np.ndarray | None",
) -> "np.ndarray":
    """Build the concatenated fingerprint (+ optional descriptor) vector for one SMILES.

    Raises:
        ValueError: if the (standardized) SMILES cannot be parsed by RDKit.
    """
    # Fall back to the raw string when standardization rejects the input.
    normalized = standardize_smiles_v2(smiles) or smiles
    mol = Chem.MolFromSmiles(normalized)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {normalized}")
    fp_blocks: list[np.ndarray] = [
        _morgan_bits_from_mol(mol, radius=int(radius), n_bits=n_bits, use_chirality=use_chirality)
        for radius in radii
    ]
    if use_maccs:
        fp_blocks.append(_maccs_bits_from_mol(mol))
    feature_blocks: list[np.ndarray] = [np.concatenate(fp_blocks, axis=0).astype(np.float32)]
    if use_rdkit_descriptors and descriptor_names:
        dense = _molecule_descriptor_vector(mol, names=descriptor_names)
        if descriptor_mean is not None and descriptor_std is not None:
            # Standardize descriptors with the training-time statistics.
            dense = (dense - descriptor_mean) / descriptor_std
        feature_blocks.append(dense.astype(np.float32))
    return np.concatenate(feature_blocks, axis=0).astype(np.float32)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def load_compatibility_model(model_dir: "str | Path") -> "BioAssayAlignCompatibilityModel":
    """Load encoder + compatibility head from a local training-artifact directory.

    Expects ``best_model.pt`` and ``training_metadata.json`` inside ``model_dir``.

    Raises:
        RuntimeError: if the checkpoint state dict does not match the head
            (ignoring ``mol_norm.*``, which is allowed to be missing).
    """
    model_path = Path(model_dir)
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    metadata = json.loads((model_path / "training_metadata.json").read_text())
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)

    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])
    state_dict = checkpoint["model_state_dict"]
    # Input dims are inferred from the projection weight shapes in the checkpoint.
    assay_dim = int(state_dict["assay_proj.weight"].shape[1])
    molecule_dim = int(state_dict["mol_proj.weight"].shape[1])
    head = CompatibilityHead(
        assay_dim=assay_dim,
        molecule_dim=molecule_dim,
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    load_result = head.load_state_dict(state_dict, strict=False)
    # NOTE(review): mol_norm parameters may be absent — presumably checkpoints
    # predating that layer; anything else missing/unexpected is an error.
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            "Compatibility checkpoint load mismatch: "
            f"unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )
    raw_mean = feature_spec.get("descriptor_mean")
    raw_std = feature_spec.get("descriptor_std")
    return BioAssayAlignCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg["assay_task_description"],
        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", False)),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", False)),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=None if raw_mean is None else np.array(raw_mean, dtype=np.float32),
        descriptor_std=None if raw_std is None else np.array(raw_std, dtype=np.float32),
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def load_compatibility_model_from_hub(repo_id: str) -> "BioAssayAlignCompatibilityModel":
    """Download the model artifacts from the Hugging Face Hub and load them."""
    # Only the two files the loader reads are fetched from the repo.
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        allow_patterns=["best_model.pt", "training_metadata.json"],
    )
    return load_compatibility_model(snapshot_path)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def rank_compounds(
|
| 512 |
+
model: BioAssayAlignCompatibilityModel,
|
| 513 |
+
*,
|
| 514 |
+
assay_text: str,
|
| 515 |
+
smiles_list: list[str],
|
| 516 |
+
top_k: int | None = None,
|
| 517 |
+
) -> list[dict[str, Any]]:
|
| 518 |
+
if not smiles_list:
|
| 519 |
+
return []
|
| 520 |
+
|
| 521 |
+
assay_features = model._build_assay_feature_array(assay_text)
|
| 522 |
+
assay_tensor = torch.from_numpy(assay_features.astype(np.float32)).unsqueeze(0)
|
| 523 |
+
|
| 524 |
+
valid_items: list[tuple[str, str]] = []
|
| 525 |
+
invalid_items: list[dict[str, Any]] = []
|
| 526 |
+
for raw_smiles in smiles_list:
|
| 527 |
+
standardized = standardize_smiles_v2(raw_smiles)
|
| 528 |
+
if standardized is None:
|
| 529 |
+
invalid_items.append(
|
| 530 |
+
{
|
| 531 |
+
"input_smiles": raw_smiles,
|
| 532 |
+
"canonical_smiles": None,
|
| 533 |
+
"smiles_hash": None,
|
| 534 |
+
"score": None,
|
| 535 |
+
"valid": False,
|
| 536 |
+
"error": "invalid_smiles",
|
| 537 |
+
}
|
| 538 |
+
)
|
| 539 |
continue
|
| 540 |
+
valid_items.append((raw_smiles, standardized))
|
| 541 |
+
|
| 542 |
+
ranked_items: list[dict[str, Any]] = []
|
| 543 |
+
if valid_items:
|
| 544 |
+
feature_matrix = model.build_molecule_feature_matrix([item[1] for item in valid_items])
|
| 545 |
+
candidate_tensor = torch.from_numpy(feature_matrix).unsqueeze(0)
|
| 546 |
+
with torch.no_grad():
|
| 547 |
+
logits, _, _ = model.compatibility_head.score_candidates(
|
| 548 |
+
assay_tensor.to(dtype=torch.float32),
|
| 549 |
+
candidate_tensor.to(dtype=torch.float32),
|
| 550 |
+
)
|
| 551 |
+
scores = logits.squeeze(0).cpu().numpy().tolist()
|
| 552 |
+
for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True):
|
| 553 |
+
ranked_items.append(
|
| 554 |
{
|
| 555 |
+
"input_smiles": raw_smiles,
|
| 556 |
+
"canonical_smiles": canonical,
|
| 557 |
+
"smiles_hash": smiles_sha256(canonical),
|
| 558 |
"score": float(score),
|
| 559 |
"valid": True,
|
| 560 |
}
|
| 561 |
)
|
| 562 |
+
ranked_items.sort(key=lambda item: item["score"], reverse=True)
|
| 563 |
+
if top_k is not None and top_k > 0:
|
| 564 |
+
ranked_items = ranked_items[:top_k]
|
| 565 |
+
|
| 566 |
+
return ranked_items + invalid_items
|
| 567 |
|
| 568 |
|
| 569 |
def list_softmax_scores(scores: list[float], temperature: float = 1.0) -> list[float]:
|