import multiprocessing as mp
import os
import sys

import numpy as np
import pandas as pd
from rdkit import Chem
from tqdm import tqdm

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Database name -> file path. Files ending in "tsv" are read tab-separated,
# files ending in "csv" comma-separated (see CandidateAssignment._load_databases).
database_to_path = {
    'fdb': "/data/yzhouc01/molecule_data/foodb_2020_04_07_csv/Compound.csv",
    'hmdb': "/data/yzhouc01/molecule_data/metabolites-2025-09-18.csv",
    'spectra_db': "/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex_processed.tsv",
    'bio_db': "/data/yzhouc01/molecule_data/bio_2023_07_11_smiles.csv",
    'coconut': "/data/yzhouc01/molecule_data/coconut_csv-05-2025.csv",
}
# Database name -> column holding the mass that queries are matched against.
db_to_mass_col = {
    'fdb': 'exact_molecular_weight',
    'hmdb': 'MONO_MASS',
    'spectra_db': 'exact_molecular_weight',
    'bio_db': 'exact_molecular_weight',
    'coconut': 'exact_molecular_weight',
}
# Database name -> column holding the SMILES returned as candidates.
db_to_smiles_col = {
    'fdb': 'CANONICAL_SMILES',
    'hmdb': 'CANONICAL_SMILES',
    'spectra_db': 'CANONICAL_SMILES',
    'bio_db': 'canonical_smiles',
    'coconut': 'rdkit_canonical_smiles',
}

# Per-process CandidateAssignment, set by _init_worker inside each pool worker.
_worker_instance = None


def _init_worker(databases, threshold):
    """Initialize the global CandidateAssignment in each worker (silent)."""
    global _worker_instance
    _worker_instance = CandidateAssignment(databases, threshold, verbose=False)


def _worker_retrieve_candidates(parent_mass):
    """Retrieve candidates using the worker's global CandidateAssignment.

    Raises:
        RuntimeError: if called in a process where _init_worker has not run.
    """
    if _worker_instance is None:
        raise RuntimeError("Worker not initialized; _init_worker must run first.")
    return _worker_instance.retrieve_candidates(parent_mass)


class CandidateAssignment:
    """Look up candidate SMILES whose mass lies within a tolerance of a query mass.

    Each configured database is read from ``database_to_path`` into a DataFrame.
    Its mass column (``db_to_mass_col``) is coerced to numeric, and its SMILES
    column (``db_to_smiles_col``) supplies the candidates.
    """

    def __init__(self, databases=None, threshold=0.01, verbose=True):
        """Validate the requested databases and load them into memory.

        Args:
            databases: Iterable of database names (keys of ``database_to_path``),
                or a single name as a string. Duplicate names are loaded once.
            threshold: Default half-width of the mass window used by
                ``retrieve_candidates``.
            verbose: Print progress messages (only from the main process).

        Raises:
            ValueError: if ``databases`` is None, a name is unknown, its file does
                not exist, or a loaded file lacks a required column.
        """
        if databases is None:
            raise ValueError(
                f"No databases specified. Available: {list(database_to_path.keys())}"
            )
        if isinstance(databases, str):
            # A bare name would otherwise be iterated character by character.
            databases = [databases]

        self.threshold = threshold
        self.databases = []
        self.verbose = verbose
        for db in databases:
            if db not in database_to_path:
                raise ValueError(
                    f"Database {db} not recognized. Available: {list(database_to_path.keys())}"
                )
            if not os.path.exists(database_to_path[db]):
                raise ValueError(f"Database file for {db} not found at {database_to_path[db]}")
            if db not in self.databases:
                self.databases.append(db)

        if self._should_log():
            print(f"[{os.getpid()}] Loading databases: {self.databases}")
        self.db_dfs = {}
        self._load_databases()

    def _should_log(self):
        """Return True when verbose output is on and we are in the main process."""
        # Pool workers construct their own silent instance; this keeps their
        # output quiet even if verbose were enabled.
        return self.verbose and mp.current_process().name == "MainProcess"

    def _load_databases(self):
        """Read every database in ``self.databases`` into ``self.db_dfs``.

        Files with an unsupported extension are skipped (with a message when
        verbose). Non-numeric masses become NaN, so they never match a query.

        Raises:
            ValueError: if the mass or SMILES column is missing from a file.
        """
        for db in self.databases:
            path = database_to_path[db]
            if path.endswith("tsv"):
                df = pd.read_csv(path, sep="\t", low_memory=False)
            elif path.endswith("csv"):
                df = pd.read_csv(path, low_memory=False)
            else:
                if self._should_log():
                    print(f"Unable to load database: {db}")
                continue

            mass_col = db_to_mass_col[db]
            smiles_col = db_to_smiles_col[db]
            for col in (mass_col, smiles_col):
                if col not in df.columns:
                    raise ValueError(
                        f"Column {col} not found in database {db}. "
                        f"{db} columns: {df.columns.tolist()}"
                    )

            df[mass_col] = pd.to_numeric(df[mass_col], errors='coerce')
            self.db_dfs[db] = df

            if self._should_log():
                print(f"[{os.getpid()}] Loaded {db} with {len(df)} entries.")

    def retrieve_candidates(self, parent_mass, threshold=None):
        """Retrieve SMILES candidates for a single parent mass.

        Args:
            parent_mass: Query mass.
            threshold: Optional half-width of the window for this call only;
                defaults to ``self.threshold``.

        Returns:
            Tuple ``(parent_mass, smiles_list)``. ``smiles_list`` holds unique
            SMILES whose mass is in ``[parent_mass - t, parent_mass + t]``
            (inclusive). They are in first-seen order across the loaded
            databases, and missing (NaN) SMILES are dropped.
        """
        tol = self.threshold if threshold is None else threshold
        lb = parent_mass - tol
        ub = parent_mass + tol

        # dict.fromkeys keeps insertion order, so results are deterministic;
        # set() ordering depends on per-process string hash randomization.
        unique_smiles = {}
        for db_name, df in self.db_dfs.items():
            in_window = df[db_to_mass_col[db_name]].between(lb, ub)
            hits = df.loc[in_window, db_to_smiles_col[db_name]].dropna()
            unique_smiles.update(dict.fromkeys(hits.tolist()))
        return parent_mass, list(unique_smiles)

    def retrieve_candidates_batch(self, parent_masses, n_workers=25, chunksize=10):
        """Retrieve candidates for many parent masses in parallel.

        Each worker process builds its own silent CandidateAssignment, so every
        worker reloads all configured databases from disk.

        Args:
            parent_masses: Iterable of query masses (generators are accepted).
            n_workers: Maximum number of worker processes. It is capped at the
                number of masses so no worker is started without work.
            chunksize: Tasks sent to a worker at a time (see ``Pool.imap``).

        Returns:
            Dict mapping each parent mass to its list of SMILES candidates.
        """
        parent_masses = list(parent_masses)
        if not parent_masses:
            return {}
        if n_workers is not None:
            n_workers = max(1, min(n_workers, len(parent_masses)))

        with mp.Pool(
            processes=n_workers,
            initializer=_init_worker,
            initargs=(self.databases, self.threshold),
        ) as pool:
            results = list(
                tqdm(
                    pool.imap(_worker_retrieve_candidates, parent_masses, chunksize=chunksize),
                    total=len(parent_masses),
                    desc="Retrieving candidates",
                )
            )
        return dict(results)


# NOTE(review): if the code below is re-enabled, check the sign conventions
# first. Most entries read as "ion m/z = M + offset", which would make the
# parent mass precursor_mz - offset, not +. The '[M-H]-' and '[M+H2O+H]+'
# offsets look inconsistent with that convention. Also, the unsupported-adduct
# branch prints but then falls through to a KeyError.
# P_TBL = Chem.GetPeriodicTable()
# ELECTRON_MASS = 0.00054858
# VALID_ELEMENTS = [
#     "C",
#     "H",
#     "As",
#     "B",
#     "Br",
#     "Cl",
#     "Co",
#     "F",
#     "Fe",
#     "I",
#     "K",
#     "N",
#     "Na",
#     "O",
#     "P",
#     "S",
#     "Se",
#     "Si",
# ]
# VALID_MONO_MASSES = np.array(
#     [P_TBL.GetMostCommonIsotopeMass(i) for i in VALID_ELEMENTS]
# )
# CHEM_MASSES = VALID_MONO_MASSES[:, None]
# ELEMENT_TO_MASS = dict(zip(VALID_ELEMENTS, CHEM_MASSES.squeeze()))
# adduct_to_mass = {
#     "[M+H]+": ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
#     "[M+Na]+": ELEMENT_TO_MASS["Na"] - ELECTRON_MASS,
#     "[M+K]+": ELEMENT_TO_MASS["K"] - ELECTRON_MASS,
#     "[M-H2O+H]+": -ELEMENT_TO_MASS["O"] - ELEMENT_TO_MASS["H"] - ELECTRON_MASS,
#     "[M+H3N+H]+": ELEMENT_TO_MASS["N"] + ELEMENT_TO_MASS["H"] * 4 - ELECTRON_MASS,
#     "[M]+": 0 - ELECTRON_MASS,
#     "[M-H4O2+H]+": -ELEMENT_TO_MASS["O"] * 2 - ELEMENT_TO_MASS["H"] * 3 - ELECTRON_MASS,
#     "[M-H]-": ELEMENT_TO_MASS["H"] + ELECTRON_MASS,
#     "[M+H2O+H]+":ELEMENT_TO_MASS["O"] * 2 + ELEMENT_TO_MASS["H"] * 2 - ELECTRON_MASS,
# }
# def calculate_parent_mass(precursor_mz, adduct):
#     if adduct not in adduct_to_mass:
#         print(f'{adduct} not supported, returning original precursor_mz')
#     return precursor_mz + adduct_to_mass[adduct]


if __name__ == "__main__":
    # get_mol_mass_for_combined()
    # The mass tolerance is configured on the instance; retrieve_candidates
    # takes only the query mass by default.
    ca = CandidateAssignment(databases=['hmdb'], threshold=0.01)
    candidates = ca.retrieve_candidates(parent_mass=180.0634)
    print(candidates)