# NOTE: removed stray page-scrape residue ("Spaces:" / "Sleeping" lines) — not part of the source.
| import os | |
| import numpy as np | |
| import pickle | |
| from typing import Callable, List, Dict, Any, Optional | |
| from rdkit import Chem | |
| import faiss | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from tqdm import tqdm | |
| import dgl | |
class MoleculeDataset(Dataset):
    """Converts SMILES to DGL graphs in parallel via DataLoader workers.

    Each item is a ``(mol_id, graph, error)`` triple: on success ``error``
    is ``None``; on failure ``graph`` is ``None`` and ``error`` carries the
    exception message, so the collate function can drop bad molecules.
    """

    def __init__(self, smiles_dict, smiles_preprocess):
        # Materialize the dict as a list so integer indexing is stable
        # across DataLoader worker processes.
        self.items = list(smiles_dict.items())
        self.smiles_preprocess = smiles_preprocess

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        mol_id, smi = self.items[idx]
        try:
            return mol_id, self.smiles_preprocess(smi), None
        except Exception as exc:
            # Defer failure handling to the collate fn / caller.
            return mol_id, None, str(exc)
def collate_graphs(batch):
    """Custom collation: keep only molecules whose graph conversion succeeded.

    Returns ``(mol_ids, batched_graph)``; when every item in the batch
    failed, returns ``([], None)`` so the caller can skip it.
    """
    mol_ids = []
    graphs = []
    for mol_id, graph, _err in batch:
        if graph is not None:
            mol_ids.append(mol_id)
            graphs.append(graph)
    if not graphs:
        return [], None
    return tuple(mol_ids), dgl.batch(graphs)
class SpectraMoleculeRetriever:
    """
    Two-stage spectra–molecule retrieval system with hierarchical metadata filtering:
    1. Coarse retrieval via FAISS on global embeddings.
    2. Fine-grained reranking via custom similarity (e.g., FILIP alignment).
    3. Supports fast subset search by class, superclass, or pathway.
    """

    def __init__(
        self,
        molecule_encoder,
        spectra_encoder,
        fine_similarity_fn: Callable[[Any, Any], float],
        smiles_preprocess: Callable[[str], Any],
        device=None,
    ):
        """
        Args:
            molecule_encoder: callable mapping ``(graph, node_features)`` to
                per-node representations (see ``build_database``).
            spectra_encoder: callable mapping a spectrum to token embeddings.
            fine_similarity_fn: function for fine-grained similarity; must
                return something with ``.item()`` (e.g. a 0-d tensor).
            smiles_preprocess: preprocessing function for SMILES → molecule object.
            device: where to run encoders; defaults to CUDA when available.
        """
        self.molecule_encoder = molecule_encoder
        self.spectra_encoder = spectra_encoder
        self.fine_similarity_fn = fine_similarity_fn
        self.smiles_preprocess = smiles_preprocess
        # Resolve the device here rather than in a default argument: a
        # default is evaluated once at class-definition time and would probe
        # CUDA on import.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Storage
        self.molecule_db: Dict[str, Any] = {}   # mol_id → per-molecule graph
        self.node_cache: Dict[str, Any] = {}    # mol_id → encoder node embeddings (CPU tensors)
        self.metadata: Dict[str, Dict[str, List[str]]] = {}  # e.g. {"class": {"lipid": [mol1, mol2], ...}}
        self.molecule_ids: Optional[np.ndarray] = None
        self.global_embeddings: Optional[np.ndarray] = None
        self.index: Optional[faiss.Index] = None
        self.smiles_dict: Optional[Dict[str, str]] = None  # mol_id → smiles
        self.failed_mols: List[str] = []

        # Set models to eval mode and move to device.
        self.molecule_encoder.eval()
        self.spectra_encoder.eval()
        self.molecule_encoder.to(self.device)
        self.spectra_encoder.to(self.device)

    # -------------------------------
    # Database building & saving
    # -------------------------------
    def build_database(
        self,
        smiles_dict: dict,
        metadata=None,
        cache_nodes: bool = False,
        batch_size: int = 64,
        num_workers: int = 25,
        pooling: str = "max",  # or "sum", "mean"
    ):
        """
        Parallelized database construction using PyTorch DataLoader for
        SMILES → DGLGraph conversion and batched GPU encoding.

        Args:
            smiles_dict: dict {mol_id: smiles}
            metadata: hierarchical dict for class/superclass/pathway
            cache_nodes: if True, store node embeddings for fine-grained search
            batch_size: number of molecules per GPU batch
            num_workers: parallel CPU workers for SMILES parsing
            pooling: global pooling type ("max" | "sum" | "mean")

        Raises:
            ValueError: for an unsupported ``pooling`` mode.
            RuntimeError: when no molecule could be encoded at all.
        """
        print("Building molecule database with PyTorch DataLoader parallelization...")
        # Set up global pooling. NB: DGL's mean-pooling module is AvgPooling;
        # a MeanPooling class does not exist (the old code raised AttributeError).
        if pooling == "max":
            self.pooling = dgl.nn.pytorch.glob.MaxPooling()
        elif pooling == "sum":
            self.pooling = dgl.nn.pytorch.glob.SumPooling()
        elif pooling == "mean":
            self.pooling = dgl.nn.pytorch.glob.AvgPooling()
        else:
            raise ValueError(f"Unsupported pooling: {pooling}")

        dataset = MoleculeDataset(smiles_dict, self.smiles_preprocess)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_graphs,
            pin_memory=True,
        )

        mol_ids_all, mol_objs, mol_embs = [], [], []
        failed_mols = []
        node_cache = {}

        with torch.no_grad():
            for mol_ids, batched_graph in tqdm(loader, desc="Encoding molecules"):
                if batched_graph is None:
                    # Every SMILES in this batch failed preprocessing.
                    continue
                try:
                    batched_graph = batched_graph.to(self.device)
                    node_repr = self.molecule_encoder(batched_graph, batched_graph.ndata['h'])
                    global_emb = self.pooling(batched_graph, node_repr)

                    # L2-normalize so FAISS inner product equals cosine
                    # similarity; float32 because FAISS requires it.
                    emb_np = global_emb.detach().cpu().numpy().astype(np.float32)
                    norms = np.linalg.norm(emb_np, axis=1, keepdims=True)
                    emb_np /= np.maximum(norms, 1e-12)  # guard all-zero rows

                    # Store *individual* graphs. The old code appended the
                    # whole batched graph once per molecule, so every
                    # molecule_db entry aliased the entire batch.
                    graphs = dgl.unbatch(batched_graph)
                    mol_ids_all.extend(mol_ids)
                    mol_objs.extend(graphs)
                    mol_embs.append(emb_np)

                    # Optionally cache the *encoder output* per molecule.
                    # (ndata['h'] holds the raw input features, not the
                    # learned node embeddings the reranker needs.)
                    if cache_nodes:
                        sizes = batched_graph.batch_num_nodes().tolist()
                        for mol_id, chunk in zip(mol_ids, torch.split(node_repr, sizes)):
                            node_cache[mol_id] = chunk.detach().cpu()
                except Exception as e:
                    failed_mols.extend(mol_ids)
                    print(f"[Warning] Failed to encode batch with molecules {mol_ids}: {e}")
                    continue

        if not mol_embs:
            raise RuntimeError("No valid molecules were successfully encoded.")

        self.failed_mols = failed_mols
        self.smiles_dict = smiles_dict
        self.molecule_db = dict(zip(mol_ids_all, mol_objs))
        self.molecule_ids = np.array(mol_ids_all)
        self.global_embeddings = np.concatenate(mol_embs, axis=0)
        self.metadata = metadata or {}
        self.node_cache.update(node_cache)
        self._build_faiss_index()

        # Anything not encoded failed either parsing (dropped silently by the
        # collate fn) or encoding (tracked in failed_mols); the set difference
        # already covers both, so don't add len(failed_mols) on top (the old
        # report double-counted encoder failures).
        n_failed = len(smiles_dict) - len(self.molecule_ids)
        print(f"Database built with {len(self.molecule_ids)} molecules "
              f"({n_failed} failed).")

    def _build_faiss_index(self):
        """(Re)build a flat inner-product index over the global embeddings."""
        # FAISS requires float32, C-contiguous input; normalize storage too so
        # subset searches reuse the same dtype.
        embs = np.ascontiguousarray(self.global_embeddings, dtype=np.float32)
        self.global_embeddings = embs
        self.index = faiss.IndexFlatIP(embs.shape[1])
        self.index.add(embs)
        print(f"FAISS index built with {len(self.molecule_ids)} embeddings.")

    def save_database(self, path: str):
        """Save ids, embeddings, metadata and caches (not the graph objects)."""
        data = {
            "molecule_ids": self.molecule_ids,
            "global_embeddings": self.global_embeddings,
            "metadata": self.metadata,
            "node_cache": self.node_cache,
            "smiles_dict": self.smiles_dict,
            "failed_mols": self.failed_mols,
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)
        print(f"Database saved to {path}")

    def load_database(self, path: str):
        """Load molecule database and rebuild FAISS index.

        WARNING: ``pickle.load`` executes arbitrary code — only load files
        produced by ``save_database`` from a trusted source.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.molecule_ids = data["molecule_ids"]
        self.global_embeddings = data["global_embeddings"]
        self.metadata = data.get("metadata", {})
        self.node_cache = data.get("node_cache", {})
        self.smiles_dict = data.get("smiles_dict", {})
        self.failed_mols = data.get("failed_mols", [])
        self._build_faiss_index()
        print(f"Database loaded from {path}")

    # -------------------------------
    # Filtering utilities
    # -------------------------------
    def _get_filtered_indices(self, subset: Optional[Dict[str, str]] = None) -> np.ndarray:
        """
        Retrieve indices (into ``molecule_ids``) matching a metadata subset.

        Example subset: {"class": "lipid"} or {"pathway": "glycolysis"}.
        Only the first key/value pair of ``subset`` is considered.
        An empty/None subset selects every molecule.
        """
        if not subset:
            return np.arange(len(self.molecule_ids))
        key, value = next(iter(subset.items()))
        if key not in self.metadata or value not in self.metadata[key]:
            print(f"[Warning] No molecules found for {key}={value}")
            return np.array([], dtype=int)
        mol_ids = self.metadata[key][value]
        id_to_idx = {m: i for i, m in enumerate(self.molecule_ids)}
        selected = [id_to_idx[m] for m in mol_ids if m in id_to_idx]
        return np.array(selected, dtype=int)

    # -------------------------------
    # Retrieval
    # -------------------------------
    def coarse_search(self, spectrum, top_k: int = 256, subset: Optional[Dict[str, str]] = None):
        """
        Retrieve top-k candidates using FAISS, optionally restricted to subset metadata.

        Returns:
            (candidate_ids, similarities) — both empty when the subset matches nothing.
        """
        with torch.no_grad():
            spectrum = spectrum.to(self.device)
            z_spec = self.spectra_encoder(spectrum).sum(axis=0)
        z_spec = z_spec.detach().cpu().numpy() if hasattr(z_spec, "detach") else np.asarray(z_spec)
        norm = np.linalg.norm(z_spec)
        if norm > 0:  # guard against a zero query vector
            z_spec = z_spec / norm
        z_spec = np.ascontiguousarray(z_spec, dtype=np.float32)  # FAISS dtype

        subset_idx = self._get_filtered_indices(subset)
        if subset_idx.size == 0:
            return [], []
        k = min(top_k, len(subset_idx))

        if subset_idx.size == len(self.molecule_ids):
            # No effective filter: reuse the prebuilt index instead of
            # rebuilding a copy of the whole database per query.
            sims, idxs = self.index.search(z_spec[None, :], k)
            candidate_ids = self.molecule_ids[idxs[0]]
        else:
            emb_subset = np.ascontiguousarray(self.global_embeddings[subset_idx])
            index_subset = faiss.IndexFlatIP(emb_subset.shape[1])
            index_subset.add(emb_subset)
            sims, idxs = index_subset.search(z_spec[None, :], k)
            candidate_ids = self.molecule_ids[subset_idx[idxs[0]]]
        return candidate_ids, sims[0]

    def fine_rerank(self, spectrum, candidate_ids: List[str], top_k: int = 50):
        """
        Compute fine-grained similarity for the candidates and rerank.

        Returns:
            List of (mol_id, score) tuples, highest score first, at most top_k.
        """
        scores = []
        # no_grad covers the per-candidate molecule encoding too (the old
        # code only guarded the spectrum pass and accumulated grad state).
        with torch.no_grad():
            spectrum = spectrum.to(self.device)
            z_spec_tokens = self.spectra_encoder(spectrum)
            for mol_id in candidate_ids:
                if mol_id in self.node_cache:
                    # Cached embeddings live on CPU; co-locate with the query.
                    mol_tokens = self.node_cache[mol_id].to(self.device)
                else:
                    if mol_id in self.molecule_db:
                        mol = self.molecule_db[mol_id].to(self.device)
                    else:
                        # Last resort: rebuild the graph from its SMILES.
                        mol = self.smiles_preprocess(self.smiles_dict[mol_id]).to(self.device)
                    # Match build_database's call convention: (graph, node feats).
                    mol_tokens = self.molecule_encoder(mol, mol.ndata['h'])
                s = self.fine_similarity_fn(z_spec_tokens, mol_tokens).item()
                scores.append((mol_id, s))
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]

    def search(
        self,
        spectrum,
        coarse_k: int = 256,
        fine_k: int = 50,
        subset: Optional[Dict[str, str]] = None,
    ):
        """
        Full two-stage search pipeline with optional subset filtering.

        Returns:
            List of (mol_id, score) tuples; empty when nothing matched.
        """
        candidate_ids, _ = self.coarse_search(spectrum, top_k=coarse_k, subset=subset)
        if len(candidate_ids) == 0:
            return []
        return self.fine_rerank(spectrum, candidate_ids, top_k=fine_k)
if __name__ == "__main__":
    import sys

    sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
    from flare.utils.data import get_spec_featurizer, get_mol_featurizer
    from flare.utils.models import get_model
    # NOTE(review): this import shadows the SpectraMoleculeRetriever defined
    # in this file — confirm the flare copy is the intended implementation.
    from flare.utils.mol_search import SpectraMoleculeRetriever
    from flare.utils.general import filip_similarity_single
    import yaml

    # Toy hierarchical metadata: each level maps a label to its molecule ids.
    metadata = {
        "class": {
            "lipid": ["mol1", "mol2"],
            "peptide": ["mol3"],
        },
        "pathway": {
            "beta-oxidation": ["mol1"],
            "glycolysis": ["mol2", "mol3"],
        },
    }
    smiles_dict = {
        "mol1": "CCO",
        "mol2": "CCN",
        "mol3": "CCC",
    }

    # Load hyper-parameters and build featurizers.
    param_pth = "/data/yzhouc01/cancer/flare.yaml"
    with open(param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
    mol_featurizer = get_mol_featurizer(params["molecule_view"], params)

    # Restore the trained model from its checkpoint.
    checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
    params["checkpoint_pth"] = checkpoint_pth
    model = get_model(params["model"], params)

    retriever = SpectraMoleculeRetriever(
        molecule_encoder=model.mol_enc_model,
        spectra_encoder=model.spec_enc_model,
        fine_similarity_fn=filip_similarity_single,
        smiles_preprocess=mol_featurizer,
    )
    retriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)

    # Filter search to molecules in a specific pathway
    # results = retriever.search(spectrum, subset={"pathway": "beta-oxidation"})
    # for mol_id, score in results[:10]:
    #     print(f"{mol_id}: {score:.3f}")