# flare/utils/mol_search.py
import os
import numpy as np
import pickle
from typing import Callable, List, Dict, Any, Optional
from rdkit import Chem
import faiss
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import dgl
class MoleculeDataset(Dataset):
"""Converts SMILES to DGL graphs in parallel via DataLoader workers."""
def __init__(self, smiles_dict, smiles_preprocess):
self.items = list(smiles_dict.items())
self.smiles_preprocess = smiles_preprocess
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
mol_id, smi = self.items[idx]
try:
graph = self.smiles_preprocess(smi)
return mol_id, graph, None
except Exception as e:
return mol_id, None, str(e)
def collate_graphs(batch):
"""Custom collation: keep only valid graphs."""
valid = [(mid, g) for mid, g, err in batch if g is not None]
if not valid:
return [], None
mol_ids, graphs = zip(*valid)
batched_graph = dgl.batch(graphs)
return mol_ids, batched_graph
class SpectraMoleculeRetriever:
"""
Two-stage spectra–molecule retrieval system with hierarchical metadata filtering:
1. Coarse retrieval via FAISS on global embeddings.
    2. Fine-grained reranking via a custom similarity (e.g., FILIP alignment).
    Both stages support fast subset search by class, superclass, or pathway.
"""
def __init__(
self,
molecule_encoder,
spectra_encoder,
fine_similarity_fn: Callable[[Any, Any], float],
smiles_preprocess: Callable[[str], Any],
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
"""
Args:
            molecule_encoder: GNN module called as molecule_encoder(graph, node_feats),
                returning per-node representations (pooled into global embeddings).
            spectra_encoder: module called as spectra_encoder(spectrum),
                returning per-token embeddings (summed into a global embedding).
fine_similarity_fn: function for fine-grained similarity.
smiles_preprocess: preprocessing function for SMILES → molecule object.
device: where to run encoders.
"""
self.molecule_encoder = molecule_encoder
self.spectra_encoder = spectra_encoder
self.fine_similarity_fn = fine_similarity_fn
self.smiles_preprocess = smiles_preprocess
self.device = device
# Storage
self.molecule_db: Dict[str, Any] = {} # mol_id → mol object
self.node_cache: Dict[str, Any] = {} # mol_id → node embeddings
self.metadata: Dict[str, Dict[str, List[str]]] = {} # e.g. {"class": {"lipid": [mol1, mol2], ...}}
self.molecule_ids: Optional[np.ndarray] = None
self.global_embeddings: Optional[np.ndarray] = None
self.index: Optional[faiss.Index] = None
self.smiles_dict: Optional[Dict[str, str]] = None # mol_id → smiles
self.failed_mols = []
# set model to eval mode and move to device
self.molecule_encoder.eval()
self.spectra_encoder.eval()
self.molecule_encoder.to(self.device)
self.spectra_encoder.to(self.device)
# -------------------------------
# Database building & saving
# -------------------------------
def build_database(
self,
smiles_dict: dict,
metadata=None,
cache_nodes: bool = False,
batch_size: int = 64,
num_workers: int = 25,
pooling: str = "max", # or "sum", "mean"
):
"""
Parallelized database construction using PyTorch DataLoader for
SMILES → DGLGraph conversion and batched GPU encoding.
Args:
smiles_dict: dict {mol_id: smiles}
metadata: hierarchical dict for class/superclass/pathway
cache_nodes: if True, store node embeddings for fine-grained search
batch_size: number of molecules per GPU batch
num_workers: parallel CPU workers for SMILES parsing
pooling: global pooling type ("max" | "sum" | "mean")
"""
print("Building molecule database with PyTorch DataLoader parallelization...")
# set up pooling
if pooling == "max":
self.pooling = dgl.nn.pytorch.glob.MaxPooling()
elif pooling == "sum":
self.pooling = dgl.nn.pytorch.glob.SumPooling()
elif pooling == "mean":
            self.pooling = dgl.nn.pytorch.glob.AvgPooling()  # DGL's mean pooling is AvgPooling
else:
raise ValueError(f"Unsupported pooling: {pooling}")
dataset = MoleculeDataset(smiles_dict, self.smiles_preprocess)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_graphs,
pin_memory=True,
)
mol_ids_all, mol_objs, mol_embs = [], [], []
failed_mols = []
node_cache = {}
with torch.no_grad():
for mol_ids, batched_graph in tqdm(loader, desc="Encoding molecules"):
if batched_graph is None:
# All failed in this batch
continue
try:
batched_graph = batched_graph.to(self.device)
node_repr = self.molecule_encoder(batched_graph, batched_graph.ndata['h'])
                    global_emb = self.pooling(batched_graph, node_repr)
                    # Normalize embeddings so inner-product search equals cosine similarity
                    emb_np = global_emb.detach().cpu().numpy()
                    emb_np /= np.linalg.norm(emb_np, axis=1, keepdims=True)
                    mol_ids_all.extend(mol_ids)
                    # Store individual graphs (not the whole batched graph) so each
                    # mol_id maps to its own molecule for later re-encoding.
                    mol_objs.extend(g.to(torch.device("cpu")) for g in dgl.unbatch(batched_graph))
                    mol_embs.append(emb_np)
                    # Optionally store node embeddings for fine-grained search
                    if cache_nodes:
                        # Split the batched node representations into per-graph chunks
                        node_splits = torch.split(node_repr, batched_graph.batch_num_nodes().tolist())
                        for mol_id, mol_nodes in zip(mol_ids, node_splits):
                            node_cache[mol_id] = mol_nodes.detach().cpu()
except Exception as e:
failed_mols.extend(mol_ids)
print(f"[Warning] Failed to encode batch with molecules {mol_ids}: {e}")
continue
if not mol_embs:
raise RuntimeError("No valid molecules were successfully encoded.")
self.failed_mols = failed_mols
self.smiles_dict = smiles_dict
self.molecule_db = dict(zip(mol_ids_all, mol_objs))
self.molecule_ids = np.array(mol_ids_all)
self.global_embeddings = np.concatenate(mol_embs, axis=0)
self.metadata = metadata or {}
self.node_cache.update(node_cache)
self._build_faiss_index()
print(f"Database built with {len(self.molecule_ids)} molecules "
f"({len(self.failed_mols) + (len(smiles_dict) - len(self.molecule_ids))} failed).")
def _build_faiss_index(self):
d = self.global_embeddings.shape[1]
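        # Exact inner-product index; the stored embeddings are L2-normalized,
        # so the returned scores are cosine similarities.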
self.index = faiss.IndexFlatIP(d)
self.index.add(self.global_embeddings)
print(f"FAISS index built with {len(self.molecule_ids)} embeddings.")
def save_database(self, path: str):
"""Save molecule database and embeddings."""
data = {
"molecule_ids": self.molecule_ids,
"global_embeddings": self.global_embeddings,
"metadata": self.metadata,
"node_cache": self.node_cache,
"smiles_dict": self.smiles_dict,
}
with open(path, "wb") as f:
pickle.dump(data, f)
print(f"Database saved to {path}")
def load_database(self, path: str):
"""Load molecule database and rebuild FAISS index."""
with open(path, "rb") as f:
data = pickle.load(f)
self.molecule_ids = data["molecule_ids"]
self.global_embeddings = data["global_embeddings"]
self.metadata = data.get("metadata", {})
self.node_cache = data.get("node_cache", {})
self.smiles_dict = data.get("smiles_dict", {})
self._build_faiss_index()
print(f"Database loaded from {path}")
# -------------------------------
# Filtering utilities
# -------------------------------
def _get_filtered_indices(self, subset: Optional[Dict[str, str]] = None) -> np.ndarray:
"""
Retrieve indices for molecules matching a given metadata subset.
        Example subset: {"class": "lipid"} or {"pathway": "glycolysis"}.
        Note: only the first (key, value) pair of the subset dict is applied.
"""
if not subset:
return np.arange(len(self.molecule_ids))
key, value = next(iter(subset.items()))
if key not in self.metadata or value not in self.metadata[key]:
print(f"[Warning] No molecules found for {key}={value}")
return np.array([], dtype=int)
mol_ids = self.metadata[key][value]
id_to_idx = {m: i for i, m in enumerate(self.molecule_ids)}
selected = [id_to_idx[m] for m in mol_ids if m in id_to_idx]
return np.array(selected, dtype=int)
# -------------------------------
# Retrieval
# -------------------------------
def coarse_search(self, spectrum, top_k: int = 256, subset: Optional[Dict[str, str]] = None):
"""
Retrieve top-k candidates using FAISS, optionally restricted to subset metadata.
"""
with torch.no_grad():
spectrum = spectrum.to(self.device)
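            # Pool token embeddings into a single global spectrum embedding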
z_spec = self.spectra_encoder(spectrum).sum(axis=0)
z_spec = z_spec.detach().cpu().numpy() if hasattr(z_spec, "detach") else np.asarray(z_spec)
z_spec = z_spec / np.linalg.norm(z_spec)
subset_idx = self._get_filtered_indices(subset)
if subset_idx.size == 0:
return [], []
# subset FAISS index
emb_subset = self.global_embeddings[subset_idx]
index_subset = faiss.IndexFlatIP(emb_subset.shape[1])
index_subset.add(emb_subset)
sims, idxs = index_subset.search(z_spec[None, :], min(top_k, len(subset_idx)))
candidate_ids = self.molecule_ids[subset_idx[idxs[0]]]
return candidate_ids, sims[0]
def fine_rerank(self, spectrum, candidate_ids: List[str], top_k: int = 50):
"""
Compute fine-grained similarity for the candidates and rerank.
"""
spectrum = spectrum.to(self.device)
        scores = []
        with torch.no_grad():
            z_spec_tokens = self.spectra_encoder(spectrum)
            for mol_id in candidate_ids:
                # Prefer cached node embeddings; otherwise re-encode the stored
                # graph, or rebuild it from SMILES as a last resort.
                if mol_id in self.node_cache:
                    mol_tokens = self.node_cache[mol_id].to(self.device)
                elif mol_id in self.molecule_db:
                    mol = self.molecule_db[mol_id].to(self.device)
                    mol_tokens = self.molecule_encoder(mol, mol.ndata['h'])
                else:
                    mol = self.smiles_preprocess(self.smiles_dict[mol_id])
                    mol = mol.to(self.device)
                    mol_tokens = self.molecule_encoder(mol, mol.ndata['h'])
                s = self.fine_similarity_fn(z_spec_tokens, mol_tokens).item()
                scores.append((mol_id, s))
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k]
def search(
self,
spectrum,
coarse_k: int = 256,
fine_k: int = 50,
subset: Optional[Dict[str, str]] = None,
):
"""
Full two-stage search pipeline with optional subset filtering.
"""
candidate_ids, _ = self.coarse_search(spectrum, top_k=coarse_k, subset=subset)
if len(candidate_ids) == 0:
return []
ranked = self.fine_rerank(spectrum, candidate_ids, top_k=fine_k)
return ranked
if __name__ == "__main__":
import sys
sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
from flare.utils.data import get_spec_featurizer, get_mol_featurizer
from flare.utils.models import get_model
from flare.utils.mol_search import SpectraMoleculeRetriever
from flare.utils.general import filip_similarity_single
import yaml
metadata = {
"class": {
"lipid": ["mol1", "mol2"],
"peptide": ["mol3"]
},
"pathway": {
"beta-oxidation": ["mol1"],
"glycolysis": ["mol2", "mol3"]
}
}
smiles_dict = {
"mol1": "CCO",
"mol2": "CCN",
"mol3": "CCC"
}
# Load model and data
param_pth = '/data/yzhouc01/cancer/flare.yaml'
with open(param_pth) as f:
params = yaml.load(f, Loader=yaml.FullLoader)
spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
# load model
checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
params['checkpoint_pth'] = checkpoint_pth
model = get_model(params['model'], params)
specMolRetriever = SpectraMoleculeRetriever(
molecule_encoder=model.mol_enc_model,
spectra_encoder=model.spec_enc_model,
fine_similarity_fn=filip_similarity_single,
smiles_preprocess=mol_featurizer
)
specMolRetriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)
# Filter search to molecules in a specific pathway
# results = specMolRetriever.search(spectrum, subset={"pathway": "beta-oxidation"})
# for mol_id, score in results[:10]:
# print(f"{mol_id}: {score:.3f}")