"""Feature extraction for Tox21 models.

Builds per-molecule feature matrices (count ECFPs, MACCS keys, RDKit
descriptors, optional toxicophore SMARTS hits, RDKit filter-catalog hits and
Tanimoto similarities to reference ligands) from SMILES strings.
"""

import json

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors, MACCSkeys
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem.FilterCatalog import FilterCatalog, FilterCatalogParams
from rdkit.Chem.MolStandardize import rdMolStandardize

# The twelve Tox21 assay endpoints. This order defines the column order of the
# "max_similarity" feature block and the target order of "db_similarity".
TOX21_TARGETS = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Column indices kept from the full RDKit descriptor vector (taken in
# ``Descriptors._descList`` order): 0-16 and 25-207, i.e. 200 columns.
# NOTE(review): which descriptors indices 17-24 correspond to (and whether the
# list still lines up) depends on the installed RDKit version -- confirm
# against the version the downstream models were trained with.
USED_200_DESCR = list(range(0, 17)) + list(range(25, 208))

# Length of the RDKit MACCS key bit vector.
_MACCS_BITS = 167

# At most this many ligands per target are read from the db-ligands JSON file.
_MAX_DB_LIGANDS_PER_TARGET = 10

# Hand-picked reference ligands per Tox21 target, as (name, SMILES) pairs.
# Every entry becomes one column of the "similarity" block (in dict order).
# NOTE(review): some SMILES look like simplified stand-ins rather than the named
# compounds (e.g. the "cccp"/"fccp" entries contain no N-N hydrazone bond).
# They are kept unchanged because trained models depend on these exact
# fingerprints -- confirm before correcting.
REFERENCE_LIGANDS = {
    "NR-AR": [
        ("testosterone", "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C"),
        ("dihydrotestosterone", "CC12CCC3C(C1CCC2O)CCC4CC(=O)CCC34C"),
        ("methyltrienolone", "CC12CCC3C(C1CCC2O)CCC4=CC(=O)C=CC34C"),
        ("flutamide", "CC(C)C(=O)Nc1ccc(c(c1)C(F)(F)F)[N+](=O)[O-]"),
        ("bicalutamide", "CC(CS(=O)(=O)c1ccc(F)cc1)(O)C(=O)Nc1ccc(C#N)c(c1)C(F)(F)F"),
        ("enzalutamide", "CNC(=O)c1ccc(N2C(=S)N(c3ccc(C#N)c(C(F)(F)F)c3)C(=O)C2(C)C)cc1F"),
    ],
    "NR-AR-LBD": [
        ("testosterone", "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C"),
        ("dihydrotestosterone", "CC12CCC3C(C1CCC2O)CCC4CC(=O)CCC34C"),
        ("bicalutamide", "CC(CS(=O)(=O)c1ccc(F)cc1)(O)C(=O)Nc1ccc(C#N)c(c1)C(F)(F)F"),
    ],
    "NR-AhR": [
        ("tcdd", "Clc1cc2Oc3cc(Cl)c(Cl)cc3Oc2cc1Cl"),
        ("benzo_a_pyrene", "c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45"),
        ("beta_naphthoflavone", "O=c1cc(-c2ccc3ccccc3c2)oc2ccc3ccccc3c12"),
        ("indirubin", "O=C1Nc2ccccc2C1=C1C(=O)Nc2ccccc21"),
    ],
    "NR-Aromatase": [
        ("exemestane", "CC12CCC3C(C1CC(=C)C2=O)CCC4=CC(=O)C=CC34C"),
        ("letrozole", "N#Cc1ccc(Cn2cncn2)c(c1)c1ccc(C#N)cc1"),
        ("anastrozole", "CC(C)(C#N)c1cc(Cn2cncn2)cc(c1)C(C)(C)C#N"),
        ("androstenedione", "CC12CCC3C(C1CCC2=O)CCC4=CC(=O)CCC34C"),
    ],
    "NR-ER": [
        ("estradiol", "CC12CCC3c4ccc(O)cc4CCC3C1CCC2O"),
        ("diethylstilbestrol", "CCC(=C(CC)c1ccc(O)cc1)c1ccc(O)cc1"),
        ("tamoxifen", "CCC(=C(c1ccccc1)c1ccc(OCCN(C)C)cc1)c1ccccc1"),
        ("genistein", "Oc1ccc(cc1)C1=COc2cc(O)cc(O)c2C1=O"),
        ("raloxifene", "Oc1ccc(cc1)c1sc2cc(O)ccc2c1C(=O)c1ccc(OCCN2CCCCC2)cc1"),
    ],
    "NR-ER-LBD": [
        ("estradiol", "CC12CCC3c4ccc(O)cc4CCC3C1CCC2O"),
        ("diethylstilbestrol", "CCC(=C(CC)c1ccc(O)cc1)c1ccc(O)cc1"),
        ("raloxifene", "Oc1ccc(cc1)c1sc2cc(O)ccc2c1C(=O)c1ccc(OCCN2CCCCC2)cc1"),
    ],
    "NR-PPAR-gamma": [
        ("rosiglitazone", "CN(CCOc1ccc(CC2SC(=O)NC2=O)cc1)c1ccccn1"),
        ("pioglitazone", "CCc1ccc(CCOc2ccc(CC3SC(=O)NC3=O)cc2)nc1"),
        ("troglitazone", "Cc1c(C)c2OC(C)(C)CCc2c(C)c1Oc1ccc(CC2SC(=O)NC2=O)cc1"),
    ],
    "SR-ARE": [
        ("sulforaphane", "CS(=O)CCCCN=C=S"),
        ("tert_butylhydroquinone", "CC(C)(C)c1cc(O)ccc1O"),
        ("curcumin", "COc1cc(C=CC(=O)CC(=O)C=Cc2ccc(O)c(OC)c2)ccc1O"),
    ],
    "SR-ATAD5": [
        ("camptothecin", "CCC1(O)C(=O)OCc2c1cc3n(c2=O)c1ccccc1nc3"),
        ("etoposide", "COc1cc(cc(OC)c1O)C1C2C(COC2=O)C(OC2OC3COC(C)OC3C(O)C2O)c2cc3OCOc3cc12"),
    ],
    "SR-HSE": [
        ("geldanamycin", "COC1CC(C)CC2=C(NCC=C(C)C(OC)C(C)C(OC(N)=O)C(C)C=C(C)C=C(C)C(=O)N1)C(=O)C=C(N)C2=O"),
        ("ganetespib", "CC(C)c1cc(-c2n[nH]c(=O)n2-c2ccc3c(ccn3C)c2)c(O)cc1O"),
    ],
    "SR-MMP": [
        ("cccp", "N#CC(=Cc1ccc([N+](=O)[O-])cc1)C#N"),
        ("fccp", "N#CC(=Cc1ccc(cc1)C(F)(F)F)C#N"),
        ("rotenone", "COc1cc2C3CC(C)OC3c3ccc4OC5OCCC5c4c3c2cc1OC"),
        ("antimycin_a", "CCCCCC(C)C(OC(=O)c1ccccc1N)C(NC(=O)c1cccc(NC=O)c1O)C(C)O"),
    ],
    "SR-p53": [
        ("nutlin_3", "COc1ccc(c(OC)c1)C1N(C(=O)C(N1c1ccc(Cl)cc1)c1ccc(Cl)cc1)C1CCNCC1"),
        ("doxorubicin", "COc1cccc2c1C(=O)c1c(O)c3CC(O)(CC(OC4CC(N)C(O)C(C)O4)c3c(O)c1C2=O)C(=O)CO"),
    ],
}


def _read_json(path):
    """Load and return the JSON document stored at ``path`` (read as UTF-8)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _stack_rows(rows, n_cols):
    """Stack 1-D feature rows into a float32 matrix.

    An empty ``rows`` list yields a ``(0, n_cols)`` matrix, so downstream
    shapes stay correct even when no molecule is valid.
    """
    if not rows:
        return np.zeros((0, n_cols), dtype=np.float32)
    return np.asarray(rows, dtype=np.float32)


class EnhancedFeatureExtractor:
    """Turn SMILES strings into a dictionary of per-molecule feature matrices.

    Feature blocks (keys of the dict returned by :meth:`extract_features`):

    * ``ecfps`` -- Morgan count fingerprint, ``ecfp_bits`` columns.
    * ``maccs`` -- MACCS keys, 167 columns.
    * ``rdkit_descrs`` -- the RDKit descriptors selected by ``USED_200_DESCR``.
    * ``tox`` -- SMARTS hits from ``toxicophores_path`` (only if a path is set).
    * ``rdkit_filters`` -- PAINS/BRENK/NIH/ZINC catalog-entry hits.
    * ``similarity`` / ``max_similarity`` -- Tanimoto similarity to each
      ``REFERENCE_LIGANDS`` entry / max similarity per Tox21 target.
    * ``db_similarity`` -- Tanimoto similarity to ligands read from
      ``db_ligands_path``.

    Expensive resources (SMARTS patterns, filter catalogs, reference
    fingerprints, standardizer) are built lazily on first use and cached.
    """

    def __init__(
        self,
        toxicophores_path=None,
        db_ligands_path=None,
        use_rdkit_filters=True,
        use_similarity=True,
        use_db_ligands=True,
        ecfp_radius=3,
        ecfp_bits=8192,
        sim_radius=2,
        sim_bits=2048,
    ):
        """Store the configuration; no files are read here.

        Args:
            toxicophores_path: JSON file with ``[[name, smarts], ...]`` or
                ``{name: smarts}``; enables the ``tox`` block when set.
            db_ligands_path: JSON file mapping Tox21 target -> list of
                ``{"smiles": ..., "name": ...}`` dicts.
            use_rdkit_filters: compute the ``rdkit_filters`` block.
            use_similarity: compute ``similarity`` and ``max_similarity``.
            use_db_ligands: compute ``db_similarity`` (needs ``db_ligands_path``).
            ecfp_radius, ecfp_bits: Morgan parameters for the ``ecfps`` block.
            sim_radius, sim_bits: Morgan parameters for all similarity blocks.
        """
        self.toxicophores_path = toxicophores_path
        self.db_ligands_path = db_ligands_path
        self.use_rdkit_filters = use_rdkit_filters
        self.use_similarity = use_similarity
        self.use_db_ligands = use_db_ligands
        self.ecfp_radius = ecfp_radius
        self.ecfp_bits = ecfp_bits
        self.sim_radius = sim_radius
        self.sim_bits = sim_bits
        # Lazily-populated caches (see the _load_* methods).
        self._toxicophore_patterns = None
        self._filter_catalogs = None
        self._ref_fps = None
        self._db_ligand_fps = None
        self._standardizer = None

    # ------------------------------------------------------------------
    # Lazily-loaded resources
    # ------------------------------------------------------------------

    def _get_standardizer(self):
        """Return the cached :class:`_Standardizer`, creating it on first use."""
        if self._standardizer is None:
            self._standardizer = _Standardizer()
        return self._standardizer

    def _sim_generator(self):
        """Return a Morgan bit-fingerprint generator for the similarity blocks."""
        return rdFingerprintGenerator.GetMorganGenerator(
            radius=self.sim_radius, fpSize=self.sim_bits
        )

    def _load_toxicophores(self):
        """Return the cached list of ``(name, pattern)`` toxicophore SMARTS.

        Entries whose SMARTS fails to parse are skipped. Returns an empty
        list when no ``toxicophores_path`` is configured.
        """
        if self._toxicophore_patterns is None:
            patterns = []
            if self.toxicophores_path:
                data = _read_json(self.toxicophores_path)
                # Accept both [[name, smarts], ...] and {name: smarts}.
                entries = data.items() if isinstance(data, dict) else data
                for name, smarts in entries:
                    pat = Chem.MolFromSmarts(smarts)
                    if pat is not None:
                        patterns.append((name, pat))
            self._toxicophore_patterns = patterns
        return self._toxicophore_patterns

    def _load_filter_catalogs(self):
        """Return the cached ``{name: FilterCatalog}`` for PAINS/BRENK/NIH/ZINC.

        Insertion order defines the column order of ``rdkit_filters``.
        """
        if self._filter_catalogs is None:
            self._filter_catalogs = {}
            for name, cat_type in [
                ("PAINS", FilterCatalogParams.FilterCatalogs.PAINS),
                ("BRENK", FilterCatalogParams.FilterCatalogs.BRENK),
                ("NIH", FilterCatalogParams.FilterCatalogs.NIH),
                ("ZINC", FilterCatalogParams.FilterCatalogs.ZINC),
            ]:
                params = FilterCatalogParams()
                params.AddCatalog(cat_type)
                self._filter_catalogs[name] = FilterCatalog(params)
        return self._filter_catalogs

    def _load_ref_fps(self):
        """Return ``{target: [(name, fp), ...]}`` for ``REFERENCE_LIGANDS``.

        Reference SMILES that fail to parse are skipped.
        """
        if self._ref_fps is None:
            self._ref_fps = {}
            gen = self._sim_generator()
            for target, ligands in REFERENCE_LIGANDS.items():
                self._ref_fps[target] = []
                for name, smi in ligands:
                    mol = Chem.MolFromSmiles(smi)
                    if mol:
                        self._ref_fps[target].append((name, gen.GetFingerprint(mol)))
        return self._ref_fps

    def _load_db_ligand_fps(self):
        """Return ``{target: [(name, fp), ...]}`` built from ``db_ligands_path``.

        Only Tox21 targets present in the file are included. At most the
        first ``_MAX_DB_LIGANDS_PER_TARGET`` ligands per target are read, and
        unparsable SMILES are skipped. Returns ``None`` when no path is set.
        """
        if self._db_ligand_fps is None and self.db_ligands_path:
            db_ligands = _read_json(self.db_ligands_path)
            gen = self._sim_generator()
            self._db_ligand_fps = {}
            for target in TOX21_TARGETS:
                if target not in db_ligands:
                    continue
                self._db_ligand_fps[target] = []
                for lig in db_ligands[target][:_MAX_DB_LIGANDS_PER_TARGET]:
                    smi = lig.get("smiles", "")
                    name = lig.get("name", "unknown")[:20]
                    mol = Chem.MolFromSmiles(smi)
                    if mol:
                        self._db_ligand_fps[target].append(
                            (name, gen.GetFingerprint(mol))
                        )
        return self._db_ligand_fps

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def extract_features(self, smiles_list):
        """Compute every enabled feature block for a sequence of SMILES.

        Args:
            smiles_list: iterable of SMILES strings. Entries that are not
                strings, fail to parse, or fail standardization are invalid.

        Returns:
            ``(features, valid_mask)``. ``features`` maps block name to a
            float32 array of shape ``(len(smiles_list), n_block_features)``
            whose rows for invalid molecules are NaN. ``valid_mask`` is a
            boolean array of length ``len(smiles_list)``.
        """
        smiles_list = list(smiles_list)
        n_total = len(smiles_list)
        standardizer = self._get_standardizer()

        mols = []
        # Explicit bool dtype: np.array([]) would be float64 and unusable as
        # a mask when smiles_list is empty.
        valid_mask = np.zeros(n_total, dtype=bool)
        for idx, smi in enumerate(smiles_list):
            mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
            if mol is None:
                continue
            std_mol, _ = standardizer.standardize_mol(mol)
            if std_mol is None:
                continue
            mols.append(std_mol)
            valid_mask[idx] = True

        def fill(block):
            return self._fill(block, valid_mask, n_total)

        features = {
            "ecfps": fill(self._compute_ecfp(mols)),
            "maccs": fill(self._compute_maccs(mols)),
            "rdkit_descrs": fill(self._compute_rdkit_descriptors(mols)),
        }
        if self.toxicophores_path:
            features["tox"] = fill(self._compute_toxicophore_features(mols))
        if self.use_rdkit_filters:
            features["rdkit_filters"] = fill(self._compute_rdkit_filter_features(mols))
        if self.use_similarity:
            features["similarity"] = fill(self._compute_similarity_features(mols))
            features["max_similarity"] = fill(
                self._compute_max_similarity_features(mols)
            )
        if self.use_db_ligands and self.db_ligands_path:
            features["db_similarity"] = fill(self._compute_db_ligand_similarity(mols))
        return features, valid_mask

    def _fill(self, features, mask, n_total):
        """Scatter per-valid-molecule rows into an ``(n_total, k)`` NaN matrix.

        A 1-D ``features`` array is treated as a single column.
        """
        features = np.asarray(features, dtype=np.float32)
        if features.ndim == 1:
            features = features.reshape(-1, 1)
        filled = np.full((n_total, features.shape[1]), np.nan, dtype=np.float32)
        filled[mask] = features
        return filled

    # ------------------------------------------------------------------
    # Individual feature blocks (each returns shape (len(mols), k))
    # ------------------------------------------------------------------

    def _compute_ecfp(self, mols):
        """Morgan count fingerprints (radius ``ecfp_radius``, ``ecfp_bits`` wide)."""
        # NOTE(review): countSimulation is kept for compatibility with trained
        # models; confirm whether it affects GetCountFingerprint at all.
        gen = rdFingerprintGenerator.GetMorganGenerator(
            countSimulation=True, fpSize=self.ecfp_bits, radius=self.ecfp_radius
        )
        rows = []
        for mol in mols:
            arr = np.zeros((self.ecfp_bits,), dtype=np.float32)
            DataStructs.ConvertToNumpyArray(gen.GetCountFingerprint(mol), arr)
            rows.append(arr)
        return _stack_rows(rows, self.ecfp_bits)

    def _compute_maccs(self, mols):
        """MACCS key bit vectors (167 columns)."""
        rows = []
        for mol in mols:
            arr = np.zeros((_MACCS_BITS,), dtype=np.float32)
            DataStructs.ConvertToNumpyArray(MACCSkeys.GenMACCSKeys(mol), arr)
            rows.append(arr)
        return _stack_rows(rows, _MACCS_BITS)

    def _compute_rdkit_descriptors(self, mols):
        """RDKit descriptors, subset to the ``USED_200_DESCR`` columns.

        All descriptors are evaluated in ``_descList`` order (as before) and
        then subset. Failing, ``None``, NaN, or infinite values become 0.0.
        """
        selected = np.asarray(USED_200_DESCR)
        rows = []
        for mol in mols:
            values = []
            for _, fn in Descriptors._descList:
                try:
                    val = fn(mol)
                    if val is None or np.isnan(val) or np.isinf(val):
                        val = 0.0
                except Exception:
                    # Best effort: any descriptor that cannot be computed is 0.
                    val = 0.0
                values.append(val)
            rows.append(np.array(values)[selected])
        return _stack_rows(rows, len(USED_200_DESCR))

    def _compute_toxicophore_features(self, mols):
        """Binary substructure hits for each loaded toxicophore SMARTS."""
        patterns = self._load_toxicophores()
        features = np.zeros((len(mols), len(patterns)), dtype=np.float32)
        for i, mol in enumerate(mols):
            for j, (_, pat) in enumerate(patterns):
                if mol.HasSubstructMatch(pat):
                    features[i, j] = 1.0
        return features

    def _compute_rdkit_filter_features(self, mols):
        """Binary hits for every entry of the PAINS/BRENK/NIH/ZINC catalogs."""
        catalogs = self._load_filter_catalogs()
        # Flatten once, in catalog order then entry index order.
        entries = [
            catalog.GetEntryWithIdx(i)
            for catalog in catalogs.values()
            for i in range(catalog.GetNumEntries())
        ]
        features = np.zeros((len(mols), len(entries)), dtype=np.float32)
        for mol_idx, mol in enumerate(mols):
            for feat_idx, entry in enumerate(entries):
                if entry.HasFilterMatch(mol):
                    features[mol_idx, feat_idx] = 1.0
        return features

    def _sim_fingerprints(self, mols):
        """Return the similarity-space Morgan bit fingerprint of each molecule."""
        gen = self._sim_generator()
        return [gen.GetFingerprint(mol) for mol in mols]

    def _compute_similarity_features(self, mols):
        """Tanimoto similarity to every parsed ``REFERENCE_LIGANDS`` entry."""
        ref_fps = self._load_ref_fps()
        flat_refs = [fp for target in REFERENCE_LIGANDS for _, fp in ref_fps[target]]
        features = np.zeros((len(mols), len(flat_refs)), dtype=np.float32)
        if flat_refs:
            for mol_idx, mol_fp in enumerate(self._sim_fingerprints(mols)):
                features[mol_idx] = DataStructs.BulkTanimotoSimilarity(
                    mol_fp, flat_refs
                )
        return features

    def _compute_max_similarity_features(self, mols):
        """Per Tox21 target, max Tanimoto similarity to its reference ligands.

        Targets without parsed reference ligands stay 0.
        """
        ref_fps = self._load_ref_fps()
        per_target = [
            [fp for _, fp in ref_fps.get(target, [])] for target in TOX21_TARGETS
        ]
        features = np.zeros((len(mols), len(TOX21_TARGETS)), dtype=np.float32)
        for mol_idx, mol_fp in enumerate(self._sim_fingerprints(mols)):
            for target_idx, target_fps in enumerate(per_target):
                if target_fps:
                    features[mol_idx, target_idx] = max(
                        DataStructs.BulkTanimotoSimilarity(mol_fp, target_fps)
                    )
        return features

    def _compute_db_ligand_similarity(self, mols):
        """Tanimoto similarity to each loaded db ligand (Tox21 target order)."""
        db_fps = self._load_db_ligand_fps()
        if not db_fps:
            return np.zeros((len(mols), 0), dtype=np.float32)
        flat_refs = [
            fp
            for target in TOX21_TARGETS
            if target in db_fps
            for _, fp in db_fps[target]
        ]
        features = np.zeros((len(mols), len(flat_refs)), dtype=np.float32)
        if flat_refs:
            for mol_idx, mol_fp in enumerate(self._sim_fingerprints(mols)):
                features[mol_idx] = DataStructs.BulkTanimotoSimilarity(
                    mol_fp, flat_refs
                )
        return features


class _Standardizer:
    """Molecule clean-up pipeline with lazily-created RDKit standardizers."""

    def __init__(self):
        self._taut_enumerator = None
        self._uncharger = None
        self._lfrag_chooser = None

    @property
    def taut_enumerator(self):
        """Cached ``rdMolStandardize.TautomerEnumerator`` (not used by standardize_mol)."""
        if self._taut_enumerator is None:
            self._taut_enumerator = rdMolStandardize.TautomerEnumerator()
        return self._taut_enumerator

    @property
    def uncharger(self):
        """Cached ``rdMolStandardize.Uncharger``."""
        if self._uncharger is None:
            self._uncharger = rdMolStandardize.Uncharger()
        return self._uncharger

    @property
    def lfrag_chooser(self):
        """Cached ``rdMolStandardize.LargestFragmentChooser``."""
        if self._lfrag_chooser is None:
            self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser()
        return self._lfrag_chooser

    def standardize_mol(self, mol_in):
        """Standardize a molecule.

        Steps: remove explicit Hs (tracking isotopes), cleanup, sanitize,
        assign stereo, keep the largest fragment, neutralize, sanitize.

        Returns:
            ``(mol, canonical_smiles)``, or ``(None, None)`` if any step raises.
        """
        try:
            params = Chem.RemoveHsParameters()
            params.removeAndTrackIsotopes = True
            mol = Chem.RemoveHs(mol_in, params, sanitize=False)
            mol = rdMolStandardize.Cleanup(mol)
            Chem.SanitizeMol(mol)
            Chem.AssignStereochemistry(mol)
            mol = self.lfrag_chooser.choose(mol)
            mol = self.uncharger.uncharge(mol)
            Chem.SanitizeMol(mol)
            mol = Chem.RemoveHs(Chem.AddHs(mol))
            can_smiles = Chem.MolToSmiles(mol)
            return mol, can_smiles
        except Exception:
            # Deliberate best effort: callers treat (None, None) as "invalid".
            return None, None


def get_feature_counts(toxicophores_path=None, db_ligands_path=None):
    """Return the number of columns each feature block has.

    Counts come from an :class:`EnhancedFeatureExtractor` with default
    settings and the same loaders it uses. They therefore match what
    ``extract_features`` produces: invalid SMARTS/SMILES are skipped, and
    only Tox21 targets in the db-ligands file count.

    Args:
        toxicophores_path: optional toxicophore JSON file; adds ``"tox"``.
        db_ligands_path: optional db-ligands JSON file; adds ``"db_similarity"``.

    Returns:
        dict mapping block name to column count.
    """
    extractor = EnhancedFeatureExtractor(
        toxicophores_path=toxicophores_path, db_ligands_path=db_ligands_path
    )
    counts = {
        "ecfps": extractor.ecfp_bits,
        "maccs": _MACCS_BITS,
        # Previously hard-coded as 208; the extractor emits len(USED_200_DESCR).
        "rdkit_descrs": len(USED_200_DESCR),
    }
    if toxicophores_path:
        counts["tox"] = len(extractor._load_toxicophores())
    counts["rdkit_filters"] = sum(
        cat.GetNumEntries() for cat in extractor._load_filter_catalogs().values()
    )
    counts["similarity"] = sum(len(fps) for fps in extractor._load_ref_fps().values())
    counts["max_similarity"] = len(TOX21_TARGETS)
    if db_ligands_path:
        db_fps = extractor._load_db_ligand_fps() or {}
        counts["db_similarity"] = sum(len(fps) for fps in db_fps.values())
    return counts