Spaces:
Running
Running
| import json | |
| import numpy as np | |
| from rdkit import Chem, DataStructs | |
| from rdkit.Chem import AllChem, Descriptors, MACCSkeys | |
| from rdkit.Chem import rdFingerprintGenerator | |
| from rdkit.Chem.FilterCatalog import FilterCatalog, FilterCatalogParams | |
| from rdkit.Chem.MolStandardize import rdMolStandardize | |
# The twelve Tox21 assay endpoints. Their order fixes the column order of the
# per-target feature blocks (max_similarity, db_similarity).
TOX21_TARGETS = [
    # NR-* endpoints
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
] + [
    # SR-* endpoints
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]
# Positions in rdkit.Chem.Descriptors._descList that are kept as the
# "rdkit_descrs" block: every index from 0 to 207 except 17..24, giving
# exactly 200 columns. Which descriptors sit at these positions depends on
# the installed RDKit version.
USED_200_DESCR = [*range(0, 17), *range(25, 208)]
# SMILES for every reference compound, keyed by the short name used in the
# feature tables. Compounds shared by several targets are listed only once.
# NOTE(review): several structures do not look like the compounds they are
# named after (e.g. "cccp"/"fccp" lack the hydrazone group, "rotenone" has
# no ketone, "antimycin_a" is acyclic). Verify against PubChem before
# relying on the names.
_REFERENCE_SMILES = {
    "testosterone": "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C",
    "dihydrotestosterone": "CC12CCC3C(C1CCC2O)CCC4CC(=O)CCC34C",
    "methyltrienolone": "CC12CCC3C(C1CCC2O)CCC4=CC(=O)C=CC34C",
    "flutamide": "CC(C)C(=O)Nc1ccc(c(c1)C(F)(F)F)[N+](=O)[O-]",
    "bicalutamide": "CC(CS(=O)(=O)c1ccc(F)cc1)(O)C(=O)Nc1ccc(C#N)c(c1)C(F)(F)F",
    "enzalutamide": "CNC(=O)c1ccc(N2C(=S)N(c3ccc(C#N)c(C(F)(F)F)c3)C(=O)C2(C)C)cc1F",
    "tcdd": "Clc1cc2Oc3cc(Cl)c(Cl)cc3Oc2cc1Cl",
    "benzo_a_pyrene": "c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45",
    "beta_naphthoflavone": "O=c1cc(-c2ccc3ccccc3c2)oc2ccc3ccccc3c12",
    "indirubin": "O=C1Nc2ccccc2C1=C1C(=O)Nc2ccccc21",
    "exemestane": "CC12CCC3C(C1CC(=C)C2=O)CCC4=CC(=O)C=CC34C",
    "letrozole": "N#Cc1ccc(Cn2cncn2)c(c1)c1ccc(C#N)cc1",
    "anastrozole": "CC(C)(C#N)c1cc(Cn2cncn2)cc(c1)C(C)(C)C#N",
    "androstenedione": "CC12CCC3C(C1CCC2=O)CCC4=CC(=O)CCC34C",
    "estradiol": "CC12CCC3c4ccc(O)cc4CCC3C1CCC2O",
    "diethylstilbestrol": "CCC(=C(CC)c1ccc(O)cc1)c1ccc(O)cc1",
    "tamoxifen": "CCC(=C(c1ccccc1)c1ccc(OCCN(C)C)cc1)c1ccccc1",
    "genistein": "Oc1ccc(cc1)C1=COc2cc(O)cc(O)c2C1=O",
    "raloxifene": "Oc1ccc(cc1)c1sc2cc(O)ccc2c1C(=O)c1ccc(OCCN2CCCCC2)cc1",
    "rosiglitazone": "CN(CCOc1ccc(CC2SC(=O)NC2=O)cc1)c1ccccn1",
    "pioglitazone": "CCc1ccc(CCOc2ccc(CC3SC(=O)NC3=O)cc2)nc1",
    "troglitazone": "Cc1c(C)c2OC(C)(C)CCc2c(C)c1Oc1ccc(CC2SC(=O)NC2=O)cc1",
    "sulforaphane": "CS(=O)CCCCN=C=S",
    "tert_butylhydroquinone": "CC(C)(C)c1cc(O)ccc1O",
    "curcumin": "COc1cc(C=CC(=O)CC(=O)C=Cc2ccc(O)c(OC)c2)ccc1O",
    "camptothecin": "CCC1(O)C(=O)OCc2c1cc3n(c2=O)c1ccccc1nc3",
    "etoposide": "COc1cc(cc(OC)c1O)C1C2C(COC2=O)C(OC2OC3COC(C)OC3C(O)C2O)c2cc3OCOc3cc12",
    "geldanamycin": "COC1CC(C)CC2=C(NCC=C(C)C(OC)C(C)C(OC(N)=O)C(C)C=C(C)C=C(C)C(=O)N1)C(=O)C=C(N)C2=O",
    "ganetespib": "CC(C)c1cc(-c2n[nH]c(=O)n2-c2ccc3c(ccn3C)c2)c(O)cc1O",
    "cccp": "N#CC(=Cc1ccc([N+](=O)[O-])cc1)C#N",
    "fccp": "N#CC(=Cc1ccc(cc1)C(F)(F)F)C#N",
    "rotenone": "COc1cc2C3CC(C)OC3c3ccc4OC5OCCC5c4c3c2cc1OC",
    "antimycin_a": "CCCCCC(C)C(OC(=O)c1ccccc1N)C(NC(=O)c1cccc(NC=O)c1O)C(C)O",
    "nutlin_3": "COc1ccc(c(OC)c1)C1N(C(=O)C(N1c1ccc(Cl)cc1)c1ccc(Cl)cc1)C1CCNCC1",
    "doxorubicin": "COc1cccc2c1C(=O)c1c(O)c3CC(O)(CC(OC4CC(N)C(O)C(C)O4)c3c(O)c1C2=O)C(=O)CO",
}

# Which reference compounds are associated with each Tox21 target, in the
# order their similarity columns are emitted.
_REFERENCE_PANELS = {
    "NR-AR": ("testosterone", "dihydrotestosterone", "methyltrienolone",
              "flutamide", "bicalutamide", "enzalutamide"),
    "NR-AR-LBD": ("testosterone", "dihydrotestosterone", "bicalutamide"),
    "NR-AhR": ("tcdd", "benzo_a_pyrene", "beta_naphthoflavone", "indirubin"),
    "NR-Aromatase": ("exemestane", "letrozole", "anastrozole", "androstenedione"),
    "NR-ER": ("estradiol", "diethylstilbestrol", "tamoxifen", "genistein",
              "raloxifene"),
    "NR-ER-LBD": ("estradiol", "diethylstilbestrol", "raloxifene"),
    "NR-PPAR-gamma": ("rosiglitazone", "pioglitazone", "troglitazone"),
    "SR-ARE": ("sulforaphane", "tert_butylhydroquinone", "curcumin"),
    "SR-ATAD5": ("camptothecin", "etoposide"),
    "SR-HSE": ("geldanamycin", "ganetespib"),
    "SR-MMP": ("cccp", "fccp", "rotenone", "antimycin_a"),
    "SR-p53": ("nutlin_3", "doxorubicin"),
}

# target -> list of (name, SMILES) pairs, consumed by the similarity features.
REFERENCE_LIGANDS = {
    target: [(name, _REFERENCE_SMILES[name]) for name in names]
    for target, names in _REFERENCE_PANELS.items()
}
class EnhancedFeatureExtractor:
    """Turn SMILES strings into a dictionary of float32 feature blocks.

    Blocks that are always produced:

    - ``ecfps``: Morgan count fingerprint.
    - ``maccs``: MACCS keys.
    - ``rdkit_descrs``: the RDKit descriptors selected by ``USED_200_DESCR``.

    Optional blocks:

    - ``tox``: toxicophore SMARTS matches, when ``toxicophores_path`` is set.
    - ``rdkit_filters``: PAINS/BRENK/NIH/ZINC catalog matches.
    - ``similarity`` / ``max_similarity``: Tanimoto similarity to
      ``REFERENCE_LIGANDS``.
    - ``db_similarity``: Tanimoto similarity to ligands read from
      ``db_ligands_path``.

    Patterns, catalogs, reference fingerprints and the standardizer are
    built lazily on first use and cached on the instance.
    """

    # Upper bound on ligand records read per target from db_ligands_path.
    _MAX_DB_LIGANDS_PER_TARGET = 10

    def __init__(
        self,
        toxicophores_path=None,
        db_ligands_path=None,
        use_rdkit_filters=True,
        use_similarity=True,
        use_db_ligands=True,
        ecfp_radius=3,
        ecfp_bits=8192,
        sim_radius=2,
        sim_bits=2048,
    ):
        """Store configuration; nothing expensive is computed here.

        Args:
            toxicophores_path: optional JSON file holding a list of
                ``[name, smarts]`` pairs.
            db_ligands_path: optional JSON file mapping target names to
                lists of records with ``"smiles"`` and ``"name"`` entries.
            use_rdkit_filters: emit the ``rdkit_filters`` block.
            use_similarity: emit the ``similarity`` and ``max_similarity``
                blocks.
            use_db_ligands: emit ``db_similarity`` (also needs
                ``db_ligands_path``).
            ecfp_radius, ecfp_bits: Morgan radius and size for ``ecfps``.
            sim_radius, sim_bits: Morgan radius and size for every
                similarity block.
        """
        self.toxicophores_path = toxicophores_path
        self.db_ligands_path = db_ligands_path
        self.use_rdkit_filters = use_rdkit_filters
        self.use_similarity = use_similarity
        self.use_db_ligands = use_db_ligands
        self.ecfp_radius = ecfp_radius
        self.ecfp_bits = ecfp_bits
        self.sim_radius = sim_radius
        self.sim_bits = sim_bits
        # Lazily populated caches (see the _load_* / _get_* methods).
        self._toxicophore_patterns = None
        self._filter_catalogs = None
        self._ref_fps = None
        self._db_ligand_fps = None
        self._standardizer = None

    # ------------------------------------------------------------------
    # Lazily built resources
    # ------------------------------------------------------------------

    def _get_standardizer(self):
        """Return the cached ``_Standardizer``, creating it on first use."""
        if self._standardizer is None:
            self._standardizer = _Standardizer()
        return self._standardizer

    def _similarity_generator(self):
        """Return a Morgan fingerprint generator using sim_radius / sim_bits."""
        return rdFingerprintGenerator.GetMorganGenerator(
            radius=self.sim_radius, fpSize=self.sim_bits
        )

    def _load_toxicophores(self):
        """Return cached ``[(name, query_mol), ...]`` from toxicophores_path.

        Entries whose SMARTS RDKit cannot parse are skipped. Returns an
        empty list when no toxicophores_path is configured.
        """
        if self._toxicophore_patterns is None:
            patterns = []
            if self.toxicophores_path:
                with open(self.toxicophores_path, encoding="utf-8") as f:
                    data = json.load(f)
                for name, smarts in data:
                    pat = Chem.MolFromSmarts(smarts)
                    if pat is not None:
                        patterns.append((name, pat))
            self._toxicophore_patterns = patterns
        return self._toxicophore_patterns

    def _load_filter_catalogs(self):
        """Return cached ``{name: FilterCatalog}`` for PAINS, BRENK, NIH and ZINC."""
        if self._filter_catalogs is None:
            catalogs = {}
            for name, cat_type in [
                ("PAINS", FilterCatalogParams.FilterCatalogs.PAINS),
                ("BRENK", FilterCatalogParams.FilterCatalogs.BRENK),
                ("NIH", FilterCatalogParams.FilterCatalogs.NIH),
                ("ZINC", FilterCatalogParams.FilterCatalogs.ZINC),
            ]:
                params = FilterCatalogParams()
                params.AddCatalog(cat_type)
                catalogs[name] = FilterCatalog(params)
            self._filter_catalogs = catalogs
        return self._filter_catalogs

    def _load_ref_fps(self):
        """Return cached ``{target: [(name, fp), ...]}`` for REFERENCE_LIGANDS.

        Ligands whose SMILES RDKit cannot parse are skipped.
        """
        if self._ref_fps is None:
            gen = self._similarity_generator()
            ref_fps = {}
            for target, ligands in REFERENCE_LIGANDS.items():
                ref_fps[target] = []
                for name, smi in ligands:
                    mol = Chem.MolFromSmiles(smi)
                    if mol is not None:
                        ref_fps[target].append((name, gen.GetFingerprint(mol)))
            self._ref_fps = ref_fps
        return self._ref_fps

    def _load_db_ligand_fps(self):
        """Return cached ``{target: [(name, fp), ...]}`` for the DB ligands.

        Only targets in ``TOX21_TARGETS`` are kept, with at most
        ``_MAX_DB_LIGANDS_PER_TARGET`` records each. Records with a missing,
        empty or unparsable SMILES are skipped. Returns None when no
        db_ligands_path is configured.
        """
        if self._db_ligand_fps is None and self.db_ligands_path:
            with open(self.db_ligands_path, encoding="utf-8") as f:
                db_ligands = json.load(f)
            gen = self._similarity_generator()
            db_fps = {}
            for target in TOX21_TARGETS:
                if target not in db_ligands:
                    continue
                fps = []
                for lig in db_ligands[target][: self._MAX_DB_LIGANDS_PER_TARGET]:
                    smi = lig.get("smiles")
                    # `or` also covers an explicit JSON null for the name.
                    name = str(lig.get("name") or "unknown")[:20]
                    if not isinstance(smi, str) or not smi:
                        continue
                    mol = Chem.MolFromSmiles(smi)
                    if mol is not None:
                        fps.append((name, gen.GetFingerprint(mol)))
                db_fps[target] = fps
            self._db_ligand_fps = db_fps
        return self._db_ligand_fps

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def extract_features(self, smiles_list):
        """Featurize every SMILES in *smiles_list*.

        Inputs that are not strings, fail to parse, or fail standardization
        are marked invalid. Their rows in every block are NaN.

        Returns:
            ``(features, valid_mask)``. ``features`` maps each block name to
            a float32 array with one row per input. ``valid_mask`` is a
            boolean array marking the rows that were featurized.
        """
        standardizer = self._get_standardizer()
        mols = []
        valid = []
        for smi in smiles_list:
            mol = self._parse_and_standardize(smi, standardizer)
            valid.append(mol is not None)
            if mol is not None:
                mols.append(mol)
        # dtype=bool keeps this a boolean index even for an empty input
        # (np.array([]) defaults to float64, which numpy rejects as an index).
        valid_mask = np.array(valid, dtype=bool)
        n_total = len(valid_mask)

        def fill(block):
            return self._fill(block, valid_mask, n_total)

        features = {
            "ecfps": fill(self._compute_ecfp(mols)),
            "maccs": fill(self._compute_maccs(mols)),
            "rdkit_descrs": fill(self._compute_rdkit_descriptors(mols)),
        }
        if self.toxicophores_path:
            features["tox"] = fill(self._compute_toxicophore_features(mols))
        if self.use_rdkit_filters:
            features["rdkit_filters"] = fill(self._compute_rdkit_filter_features(mols))
        if self.use_similarity:
            features["similarity"] = fill(self._compute_similarity_features(mols))
            features["max_similarity"] = fill(
                self._compute_max_similarity_features(mols)
            )
        if self.use_db_ligands and self.db_ligands_path:
            features["db_similarity"] = fill(self._compute_db_ligand_similarity(mols))
        return features, valid_mask

    @staticmethod
    def _parse_and_standardize(smi, standardizer):
        """Return the standardized molecule for *smi*, or None if it is invalid."""
        # Non-string entries (None, NaN from pandas, ...) would raise inside
        # MolFromSmiles; treat them as invalid rows instead.
        if not isinstance(smi, str):
            return None
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        std_mol, _ = standardizer.standardize_mol(mol)
        return std_mol

    def _fill(self, features, mask, n_total):
        """Scatter rows of *features* into an ``(n_total, n_features)`` matrix.

        Row ``i`` of the result receives the next feature row when
        ``mask[i]`` is true and is NaN otherwise. A 1-D *features* array is
        treated as a single column.
        """
        features = np.asarray(features, dtype=np.float32)
        if features.ndim == 1:
            features = features.reshape(-1, 1)
        filled = np.full((n_total, features.shape[1]), np.nan, dtype=np.float32)
        filled[np.asarray(mask, dtype=bool)] = features
        return filled

    @staticmethod
    def _stack(rows, n_cols):
        """Stack per-molecule rows into a float32 ``(len(rows), n_cols)`` matrix.

        Unlike ``np.array(rows)``, this keeps the column dimension when
        *rows* is empty.
        """
        if not rows:
            return np.zeros((0, n_cols), dtype=np.float32)
        return np.asarray(rows, dtype=np.float32)

    # ------------------------------------------------------------------
    # Individual feature blocks
    # ------------------------------------------------------------------

    def _compute_ecfp(self, mols):
        """Morgan count fingerprints, shape ``(len(mols), ecfp_bits)``."""
        gen = rdFingerprintGenerator.GetMorganGenerator(
            countSimulation=True, fpSize=self.ecfp_bits, radius=self.ecfp_radius
        )
        rows = []
        for mol in mols:
            arr = np.zeros((self.ecfp_bits,), dtype=np.float32)
            DataStructs.ConvertToNumpyArray(gen.GetCountFingerprint(mol), arr)
            rows.append(arr)
        return self._stack(rows, self.ecfp_bits)

    def _compute_maccs(self, mols):
        """MACCS key bits, shape ``(len(mols), 167)``."""
        rows = []
        for mol in mols:
            arr = np.zeros((167,), dtype=np.float32)
            DataStructs.ConvertToNumpyArray(MACCSkeys.GenMACCSKeys(mol), arr)
            rows.append(arr)
        return self._stack(rows, 167)

    def _compute_rdkit_descriptors(self, mols):
        """RDKit descriptors at USED_200_DESCR positions, shape ``(len(mols), 200)``.

        None, NaN, infinite or raising descriptors are recorded as 0.0.
        """
        desc_fns = [fn for _, fn in Descriptors._descList]
        rows = []
        for mol in mols:
            values = []
            for fn in desc_fns:
                try:
                    val = fn(mol)
                    if val is None or np.isnan(val) or np.isinf(val):
                        val = 0.0
                except Exception:
                    # Best effort: a descriptor that fails contributes 0.0.
                    val = 0.0
                values.append(val)
            rows.append(np.array(values)[USED_200_DESCR])
        descrs = self._stack(rows, len(USED_200_DESCR))
        # Finite float64 values beyond the float32 range turn into +/-inf
        # on the cast; apply the same "non-finite -> 0.0" rule to them.
        descrs[~np.isfinite(descrs)] = 0.0
        return descrs

    def _compute_toxicophore_features(self, mols):
        """Binary toxicophore matches, shape ``(len(mols), n_patterns)``."""
        patterns = self._load_toxicophores()
        features = np.zeros((len(mols), len(patterns)), dtype=np.float32)
        for i, mol in enumerate(mols):
            for j, (_, pat) in enumerate(patterns):
                if mol.HasSubstructMatch(pat):
                    features[i, j] = 1.0
        return features

    def _compute_rdkit_filter_features(self, mols):
        """Binary filter-catalog matches, one column per catalog entry."""
        catalogs = self._load_filter_catalogs()
        # Fetch every entry once instead of once per molecule.
        entries = [
            catalog.GetEntryWithIdx(i)
            for catalog in catalogs.values()
            for i in range(catalog.GetNumEntries())
        ]
        features = np.zeros((len(mols), len(entries)), dtype=np.float32)
        for mol_idx, mol in enumerate(mols):
            for feat_idx, entry in enumerate(entries):
                if entry.HasFilterMatch(mol):
                    features[mol_idx, feat_idx] = 1.0
        return features

    def _bulk_similarity(self, mols, ref_fps):
        """Tanimoto similarity of each molecule to each fingerprint in *ref_fps*."""
        features = np.zeros((len(mols), len(ref_fps)), dtype=np.float32)
        if not ref_fps:
            return features
        gen = self._similarity_generator()
        for mol_idx, mol in enumerate(mols):
            features[mol_idx] = DataStructs.BulkTanimotoSimilarity(
                gen.GetFingerprint(mol), ref_fps
            )
        return features

    def _compute_similarity_features(self, mols):
        """Similarity to every reference ligand, in REFERENCE_LIGANDS order."""
        ref_fps = self._load_ref_fps()
        flat = [fp for target in REFERENCE_LIGANDS for _, fp in ref_fps[target]]
        return self._bulk_similarity(mols, flat)

    def _compute_max_similarity_features(self, mols):
        """Max similarity to each target's reference ligands, one column per TOX21 target.

        Targets with no usable reference ligand get 0.0.
        """
        ref_fps = self._load_ref_fps()
        per_target = [
            [fp for _, fp in ref_fps.get(target, ())] for target in TOX21_TARGETS
        ]
        features = np.zeros((len(mols), len(TOX21_TARGETS)), dtype=np.float32)
        gen = self._similarity_generator()
        for mol_idx, mol in enumerate(mols):
            mol_fp = gen.GetFingerprint(mol)
            for target_idx, fps in enumerate(per_target):
                if fps:
                    features[mol_idx, target_idx] = max(
                        DataStructs.BulkTanimotoSimilarity(mol_fp, fps)
                    )
        return features

    def _compute_db_ligand_similarity(self, mols):
        """Similarity to every loaded DB ligand, grouped in TOX21_TARGETS order."""
        db_fps = self._load_db_ligand_fps()
        if not db_fps:
            return np.zeros((len(mols), 0), dtype=np.float32)
        flat = [
            fp for target in TOX21_TARGETS for _, fp in db_fps.get(target, ())
        ]
        return self._bulk_similarity(mols, flat)
| class _Standardizer: | |
| def __init__(self): | |
| self._taut_enumerator = None | |
| self._uncharger = None | |
| self._lfrag_chooser = None | |
| def taut_enumerator(self): | |
| if self._taut_enumerator is None: | |
| self._taut_enumerator = rdMolStandardize.TautomerEnumerator() | |
| return self._taut_enumerator | |
| def uncharger(self): | |
| if self._uncharger is None: | |
| self._uncharger = rdMolStandardize.Uncharger() | |
| return self._uncharger | |
| def lfrag_chooser(self): | |
| if self._lfrag_chooser is None: | |
| self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser() | |
| return self._lfrag_chooser | |
| def standardize_mol(self, mol_in): | |
| try: | |
| params = Chem.RemoveHsParameters() | |
| params.removeAndTrackIsotopes = True | |
| mol = Chem.RemoveHs(mol_in, params, sanitize=False) | |
| mol = rdMolStandardize.Cleanup(mol) | |
| Chem.SanitizeMol(mol) | |
| Chem.AssignStereochemistry(mol) | |
| mol = self.lfrag_chooser.choose(mol) | |
| mol = self.uncharger.uncharge(mol) | |
| Chem.SanitizeMol(mol) | |
| mol = Chem.RemoveHs(Chem.AddHs(mol)) | |
| can_smiles = Chem.MolToSmiles(mol) | |
| return mol, can_smiles | |
| except Exception: | |
| return None, None | |
def get_feature_counts(toxicophores_path=None, db_ligands_path=None, ecfp_bits=8192):
    """Return the expected number of columns in each feature block.

    The counts follow EnhancedFeatureExtractor with default settings.
    Pass ``ecfp_bits`` to match a non-default extractor. Entries whose
    SMARTS/SMILES RDKit cannot parse are not counted, since the extractor
    skips them. Database ligands with a missing or empty SMILES are not
    counted either.

    Args:
        toxicophores_path: optional JSON file holding ``[name, smarts]``
            pairs.
        db_ligands_path: optional JSON file mapping target names to lists of
            records with a ``"smiles"`` entry.
        ecfp_bits: size of the ``ecfps`` block (extractor default 8192).

    Returns:
        dict mapping block name to column count.
    """

    def _parses(smiles):
        # True for a non-empty SMILES string that RDKit can parse.
        return (
            isinstance(smiles, str)
            and bool(smiles)
            and Chem.MolFromSmiles(smiles) is not None
        )

    counts = {
        "ecfps": ecfp_bits,
        "maccs": 167,
        # Only the USED_200_DESCR subset of RDKit descriptors is emitted.
        "rdkit_descrs": len(USED_200_DESCR),
    }
    if toxicophores_path:
        with open(toxicophores_path, encoding="utf-8") as f:
            tox_data = json.load(f)
        counts["tox"] = sum(
            1 for _, smarts in tox_data if Chem.MolFromSmarts(smarts) is not None
        )
    rdkit_count = 0
    for cat_type in [
        FilterCatalogParams.FilterCatalogs.PAINS,
        FilterCatalogParams.FilterCatalogs.BRENK,
        FilterCatalogParams.FilterCatalogs.NIH,
        FilterCatalogParams.FilterCatalogs.ZINC,
    ]:
        params = FilterCatalogParams()
        params.AddCatalog(cat_type)
        rdkit_count += FilterCatalog(params).GetNumEntries()
    counts["rdkit_filters"] = rdkit_count
    counts["similarity"] = sum(
        1
        for ligands in REFERENCE_LIGANDS.values()
        for _, smi in ligands
        if _parses(smi)
    )
    counts["max_similarity"] = len(TOX21_TARGETS)
    if db_ligands_path:
        with open(db_ligands_path, encoding="utf-8") as f:
            db_ligands = json.load(f)
        # The extractor reads only TOX21_TARGETS keys, at most 10 records each.
        counts["db_similarity"] = sum(
            1
            for target in TOX21_TARGETS
            if target in db_ligands
            for lig in db_ligands[target][:10]
            if _parses(lig.get("smiles"))
        )
    return counts