import os
import pickle
from typing import Callable, List, Dict, Any, Optional

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import dgl
import faiss
from rdkit import Chem


class MoleculeDataset(Dataset):
    """Converts SMILES to DGL graphs in parallel via DataLoader workers."""

    def __init__(self, smiles_dict, smiles_preprocess):
        """
        Args:
            smiles_dict: dict {mol_id: smiles}.
            smiles_preprocess: callable mapping a SMILES string to a DGL graph.
        """
        self.items = list(smiles_dict.items())
        self.smiles_preprocess = smiles_preprocess

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return (mol_id, graph, error) — graph is None on failure."""
        mol_id, smi = self.items[idx]
        try:
            graph = self.smiles_preprocess(smi)
            return mol_id, graph, None
        except Exception as e:
            # Defer failure handling to the collate fn / consumer so one bad
            # SMILES does not kill the whole DataLoader worker.
            return mol_id, None, str(e)


def collate_graphs(batch):
    """Custom collation: keep only valid graphs.

    Returns:
        (mol_ids, batched_graph) — ([], None) when every item in the batch
        failed preprocessing.
    """
    valid = [(mid, g) for mid, g, err in batch if g is not None]
    if not valid:
        return [], None
    mol_ids, graphs = zip(*valid)
    batched_graph = dgl.batch(graphs)
    return mol_ids, batched_graph


class SpectraMoleculeRetriever:
    """
    Two-stage spectra–molecule retrieval system with hierarchical metadata filtering:
    1. Coarse retrieval via FAISS on global embeddings.
    2. Fine-grained reranking via custom similarity (e.g., FILIP alignment).
    3. Supports fast subset search by class, superclass, or pathway.
    """

    def __init__(
        self,
        molecule_encoder,
        spectra_encoder,
        fine_similarity_fn: Callable[[Any, Any], float],
        smiles_preprocess: Callable[[str], Any],
        # NOTE: this default is evaluated once at class-definition time; pass a
        # device explicitly if CUDA availability may change after import.
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        """
        Args:
            molecule_encoder: callable encoder; invoked as
                molecule_encoder(graph, graph.ndata['h']) → per-node embeddings.
            spectra_encoder: callable encoder; invoked as
                spectra_encoder(spectrum) → per-token embeddings.
            fine_similarity_fn: function for fine-grained similarity
                (spectrum tokens, molecule node embeddings) → scalar tensor.
            smiles_preprocess: preprocessing function for SMILES → molecule object.
            device: where to run encoders.
        """
        self.molecule_encoder = molecule_encoder
        self.spectra_encoder = spectra_encoder
        self.fine_similarity_fn = fine_similarity_fn
        self.smiles_preprocess = smiles_preprocess
        self.device = device

        # Storage
        self.molecule_db: Dict[str, Any] = {}   # mol_id → individual DGL graph
        self.node_cache: Dict[str, Any] = {}    # mol_id → encoder node embeddings (CPU)
        self.metadata: Dict[str, Dict[str, List[str]]] = {}  # e.g. {"class": {"lipid": [mol1, mol2], ...}}
        self.molecule_ids: Optional[np.ndarray] = None
        self.global_embeddings: Optional[np.ndarray] = None
        self.index: Optional[faiss.Index] = None
        self.smiles_dict: Optional[Dict[str, str]] = None  # mol_id → smiles
        self.failed_mols = []  # ids that failed at the *encoding* stage

        # set model to eval mode and move to device
        self.molecule_encoder.eval()
        self.spectra_encoder.eval()
        self.molecule_encoder.to(self.device)
        self.spectra_encoder.to(self.device)

    # -------------------------------
    # Database building & saving
    # -------------------------------
    def build_database(
        self,
        smiles_dict: dict,
        metadata=None,
        cache_nodes: bool = False,
        batch_size: int = 64,
        num_workers: int = 25,
        pooling: str = "max",  # or "sum", "mean"
    ):
        """
        Parallelized database construction using PyTorch DataLoader for
        SMILES → DGLGraph conversion and batched GPU encoding.

        Args:
            smiles_dict: dict {mol_id: smiles}
            metadata: hierarchical dict for class/superclass/pathway
            cache_nodes: if True, store node embeddings for fine-grained search
            batch_size: number of molecules per GPU batch
            num_workers: parallel CPU workers for SMILES parsing
            pooling: global pooling type ("max" | "sum" | "mean")

        Raises:
            ValueError: on an unsupported pooling name.
            RuntimeError: when no molecule could be encoded.
        """
        print("Building molecule database with PyTorch DataLoader parallelization...")

        # set up pooling
        # BUG FIX: DGL's mean-pooling readout is AvgPooling; MeanPooling does
        # not exist and previously raised AttributeError for pooling="mean".
        if pooling == "max":
            self.pooling = dgl.nn.pytorch.glob.MaxPooling()
        elif pooling == "sum":
            self.pooling = dgl.nn.pytorch.glob.SumPooling()
        elif pooling == "mean":
            self.pooling = dgl.nn.pytorch.glob.AvgPooling()
        else:
            raise ValueError(f"Unsupported pooling: {pooling}")

        dataset = MoleculeDataset(smiles_dict, self.smiles_preprocess)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_graphs,
            pin_memory=True,
        )

        mol_ids_all, mol_objs, mol_embs = [], [], []
        failed_mols = []
        node_cache = {}

        with torch.no_grad():
            for mol_ids, batched_graph in tqdm(loader, desc="Encoding molecules"):
                if batched_graph is None:
                    # All failed in this batch (dropped at preprocessing)
                    continue
                try:
                    batched_graph = batched_graph.to(self.device)
                    node_repr = self.molecule_encoder(batched_graph, batched_graph.ndata['h'])
                    global_emb = self.pooling(batched_graph, node_repr)

                    # Normalize embeddings (guard zero vectors against div-by-0 NaNs)
                    emb_np = global_emb.detach().cpu().numpy()
                    norms = np.linalg.norm(emb_np, axis=1, keepdims=True)
                    norms[norms == 0.0] = 1.0
                    emb_np /= norms

                    # BUG FIX: store each molecule's own graph, not the shared
                    # batched graph repeated len(mol_ids) times.
                    graphs = dgl.unbatch(batched_graph)
                    mol_ids_all.extend(mol_ids)
                    mol_objs.extend(graphs)
                    mol_embs.append(emb_np)

                    # Optionally store node embeddings for fine-grained search
                    if cache_nodes:
                        # BUG FIX: cache the encoder output (node_repr), split
                        # per graph — previously this cached the raw input
                        # features ndata['h'] of the unbatched graphs.
                        counts = batched_graph.batch_num_nodes().tolist()
                        chunks = torch.split(node_repr, counts, dim=0)
                        for mol_id, chunk in zip(mol_ids, chunks):
                            node_cache[mol_id] = chunk.detach().cpu()
                except Exception as e:
                    failed_mols.extend(mol_ids)
                    print(f"[Warning] Failed to encode batch with molecules {mol_ids}: {e}")
                    continue

        if not mol_embs:
            raise RuntimeError("No valid molecules were successfully encoded.")

        self.failed_mols = failed_mols
        self.smiles_dict = smiles_dict
        self.molecule_db = dict(zip(mol_ids_all, mol_objs))
        self.molecule_ids = np.array(mol_ids_all)
        self.global_embeddings = np.concatenate(mol_embs, axis=0)
        self.metadata = metadata or {}
        self.node_cache.update(node_cache)
        self._build_faiss_index()

        # BUG FIX: (len(smiles_dict) - len(molecule_ids)) already includes the
        # encode-stage failures; adding len(failed_mols) double-counted them.
        n_failed = len(smiles_dict) - len(self.molecule_ids)
        print(f"Database built with {len(self.molecule_ids)} molecules "
              f"({n_failed} failed).")

    def _build_faiss_index(self):
        """Build an exact inner-product index; embeddings are L2-normalized,
        so inner product equals cosine similarity."""
        d = self.global_embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        self.index.add(self.global_embeddings)
        print(f"FAISS index built with {len(self.molecule_ids)} embeddings.")

    def save_database(self, path: str):
        """Save molecule database and embeddings.

        NOTE: molecule_db (the DGL graphs) is intentionally not serialized;
        after load_database, fine_rerank falls back to re-featurizing from
        smiles_dict for molecules missing from node_cache.
        """
        data = {
            "molecule_ids": self.molecule_ids,
            "global_embeddings": self.global_embeddings,
            "metadata": self.metadata,
            "node_cache": self.node_cache,
            "smiles_dict": self.smiles_dict,
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)
        print(f"Database saved to {path}")

    def load_database(self, path: str):
        """Load molecule database and rebuild FAISS index.

        SECURITY: pickle.load executes arbitrary code from the file — only
        load databases from trusted sources.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.molecule_ids = data["molecule_ids"]
        self.global_embeddings = data["global_embeddings"]
        self.metadata = data.get("metadata", {})
        self.node_cache = data.get("node_cache", {})
        self.smiles_dict = data.get("smiles_dict", {})
        self._build_faiss_index()
        print(f"Database loaded from {path}")

    # -------------------------------
    # Filtering utilities
    # -------------------------------
    def _get_filtered_indices(self, subset: Optional[Dict[str, str]] = None) -> np.ndarray:
        """
        Retrieve indices for molecules matching a given metadata subset.

        Example subset: {"class": "lipid"} or {"pathway": "glycolysis"}.
        Multiple keys are intersected (previously only the first key was used).

        Returns:
            np.ndarray of integer indices into self.molecule_ids (empty when
            no molecule matches).
        """
        if not subset:
            return np.arange(len(self.molecule_ids))

        selected_ids = None
        for key, value in subset.items():
            if key not in self.metadata or value not in self.metadata[key]:
                print(f"[Warning] No molecules found for {key}={value}")
                return np.array([], dtype=int)
            ids = set(self.metadata[key][value])
            selected_ids = ids if selected_ids is None else (selected_ids & ids)

        id_to_idx = {m: i for i, m in enumerate(self.molecule_ids)}
        selected = [id_to_idx[m] for m in selected_ids if m in id_to_idx]
        return np.array(sorted(selected), dtype=int)

    # -------------------------------
    # Retrieval
    # -------------------------------
    def coarse_search(self, spectrum, top_k: int = 256, subset: Optional[Dict[str, str]] = None):
        """
        Retrieve top-k candidates using FAISS, optionally restricted to subset metadata.

        Returns:
            (candidate_ids, similarities) — both empty when the subset matches
            no molecules.
        """
        with torch.no_grad():
            spectrum = spectrum.to(self.device)
            # Sum-pool spectrum tokens into a single global query vector.
            z_spec = self.spectra_encoder(spectrum).sum(axis=0)
            z_spec = z_spec.detach().cpu().numpy() if hasattr(z_spec, "detach") else np.asarray(z_spec)
            z_spec = z_spec / np.linalg.norm(z_spec)

        subset_idx = self._get_filtered_indices(subset)
        if subset_idx.size == 0:
            return [], []

        # Build a temporary FAISS index over the subset (rebuilt per query;
        # acceptable for moderate subset sizes).
        emb_subset = self.global_embeddings[subset_idx]
        index_subset = faiss.IndexFlatIP(emb_subset.shape[1])
        index_subset.add(emb_subset)

        sims, idxs = index_subset.search(z_spec[None, :], min(top_k, len(subset_idx)))
        candidate_ids = self.molecule_ids[subset_idx[idxs[0]]]
        return candidate_ids, sims[0]

    def fine_rerank(self, spectrum, candidate_ids: List[str], top_k: int = 50):
        """
        Compute fine-grained similarity for the candidates and rerank.

        Returns:
            List of (mol_id, score) sorted by descending score, truncated to top_k.
        """
        spectrum = spectrum.to(self.device)
        with torch.no_grad():
            z_spec_tokens = self.spectra_encoder(spectrum)

            scores = []
            for mol_id in candidate_ids:
                if mol_id in self.node_cache:
                    # BUG FIX: cached tokens live on CPU; move to the encoder
                    # device before the similarity call.
                    mol_tokens = self.node_cache[mol_id].to(self.device)
                else:
                    if mol_id in self.molecule_db:
                        mol = self.molecule_db[mol_id].to(self.device)
                    else:
                        mol = self.smiles_preprocess(self.smiles_dict[mol_id])
                        mol = mol.to(self.device)
                    # BUG FIX: call the encoder with (graph, features), matching
                    # every other call site; previously only the graph was passed.
                    mol_tokens = self.molecule_encoder(mol, mol.ndata['h'])
                s = self.fine_similarity_fn(z_spec_tokens, mol_tokens).item()
                scores.append((mol_id, s))

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]

    def search(
        self,
        spectrum,
        coarse_k: int = 256,
        fine_k: int = 50,
        subset: Optional[Dict[str, str]] = None,
    ):
        """
        Full two-stage search pipeline with optional subset filtering.

        Returns:
            List of (mol_id, score) pairs, best first; empty list when no
            candidate survives the coarse stage.
        """
        candidate_ids, _ = self.coarse_search(spectrum, top_k=coarse_k, subset=subset)
        if len(candidate_ids) == 0:
            return []
        ranked = self.fine_rerank(spectrum, candidate_ids, top_k=fine_k)
        return ranked


if __name__ == "__main__":
    import sys
    sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
    from flare.utils.data import get_spec_featurizer, get_mol_featurizer
    from flare.utils.models import get_model
    from flare.utils.mol_search import SpectraMoleculeRetriever
    from flare.utils.general import filip_similarity_single
    import yaml

    # Toy hierarchical metadata for subset-filtered search.
    metadata = {
        "class": {
            "lipid": ["mol1", "mol2"],
            "peptide": ["mol3"],
        },
        "pathway": {
            "beta-oxidation": ["mol1"],
            "glycolysis": ["mol2", "mol3"],
        },
    }
    smiles_dict = {
        "mol1": "CCO",
        "mol2": "CCN",
        "mol3": "CCC",
    }

    # Load model and data
    param_pth = '/data/yzhouc01/cancer/flare.yaml'
    with open(param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)

    # load model
    checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
    params['checkpoint_pth'] = checkpoint_pth
    model = get_model(params['model'], params)

    specMolRetriever = SpectraMoleculeRetriever(
        molecule_encoder=model.mol_enc_model,
        spectra_encoder=model.spec_enc_model,
        fine_similarity_fn=filip_similarity_single,
        smiles_preprocess=mol_featurizer,
    )
    specMolRetriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)

    # Filter search to molecules in a specific pathway
    # results = specMolRetriever.search(spectrum, subset={"pathway": "beta-oxidation"})
    # for mol_id, score in results[:10]:
    #     print(f"{mol_id}: {score:.3f}")