Spaces:
Running
Running
| import logging | |
| import pathlib | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from gt4sd.properties.molecules import MOLECULE_PROPERTY_PREDICTOR_FACTORY | |
| from utils import draw_grid_predict | |
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# MolFormer predictor keys mapped to the algorithm versions offered for each.
# The script section hides these raw keys from the dropdown and instead adds one
# "<Key> (<version>)" entry per version listed here.
MOLFORMER_VERSIONS = {
    "molformer_classification": ["bace", "bbbp", "hiv"],
    "molformer_regression": [
        "alpha",
        "cv",
        "g298",
        "gap",
        "h298",
        "homo",
        "lipo",
        "lumo",
        "mu",
        "r2",
        "u0",
    ],
    "molformer_multitask_classification": ["clintox", "sider", "tox21"],
}

# Predictor keys removed from the property dropdown (the MolFormer keys are
# re-added per version, see MOLFORMER_VERSIONS).
REMOVE = [
    "docking",
    "docking_tdc",
    "molecule_one",
    "askcos",
    "plogp",
    "similarity_seed",
    "activity_against_target",
    "organtox",
    *MOLFORMER_VERSIONS,
]

# Dropdown labels of multi-output predictors -> comma-separated names used as
# the output-column labels (presumably in the model's output order — confirm
# against gt4sd). main() splits these strings on ",".
MODEL_PROP_DESCRIPTION = {
    "Tox21": "NR-AR, NR-AR-LBD, NR-AhR, NR-Aromatase, NR-ER, NR-ER-LBD, NR-PPAR-gamma, SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, SR-p53",
    "Sider": (
        "Hepatobiliary disorders,"
        "Metabolism and nutrition disorders,"
        "Product issues,"
        "Eye disorders,"
        "Investigations,"
        "Musculoskeletal disorders,"
        "Gastrointestinal disorders,"
        "Social circumstances,"
        "Immune system disorders,"
        "Reproductive system and breast disorders,"
        "Benign & malignant,"
        "General disorders,"
        "Endocrine disorders,"
        "Surgical & medical procedures,"
        "Vascular disorders,"
        "Blood & lymphatic disorders,"
        "Skin & subcutaneous disorders,"
        "Congenital & genetic disorders,"
        "Infections,"
        "Respiratory & thoracic disorders,"
        "Psychiatric disorders,"
        "Renal & urinary disorders,"
        "Pregnancy conditions,"
        "Ear disorders,"
        "Cardiac disorders,"
        "Nervous system disorders,"
        "Injury & procedural complications"
    ),
    "Clintox": "FDA approval, Clinical trial failure",
}
def _collect_smiles(smiles, smiles_file):
    """Return the SMILES to score, taken from exactly one of the two UI inputs.

    Raises:
        ValueError: If both inputs are given, or if neither is.
    """
    # Treat None and whitespace-only text as "nothing typed"; surrounding
    # whitespace is never part of a SMILES string.
    smiles = (smiles or "").strip()
    if smiles and smiles_file is not None:
        raise ValueError("Pass either smiles or smiles_file, not both.")
    if smiles:
        return [smiles]
    if smiles_file is None:
        raise ValueError("Pass either smiles or smiles_file.")
    # Depending on the Gradio version / `type` setting, gr.File yields either a
    # temp-file object (path in `.name`) or the path string itself.
    path = (
        smiles_file
        if isinstance(smiles_file, (str, pathlib.PurePath))
        else smiles_file.name
    )
    # Tab-separated, no header row; the SMILES are in the first column.
    return pd.read_csv(path, header=None, sep="\t")[0].tolist()


def _load_predictor(label):
    """Instantiate the predictor selected in the dropdown.

    Returns:
        ``(model, name)`` where ``name`` is ``label`` with any MolFormer
        ``" (<version>)"`` suffix removed.
    """
    name, version = label, None
    if "Molformer" in label:
        # "Molformer_<task> (<version>)": base name before the first space,
        # version between the parentheses.
        version = label.split(" ")[-1].split("(")[-1].split(")")[0]
        name = label.split(" ")[0]
    algo, config = MOLECULE_PROPERTY_PREDICTOR_FACTORY[name.lower()]
    # Predictors that have an output-label description are pinned to "v0".
    kwargs = {"algorithm_version": "v0"} if name in MODEL_PROP_DESCRIPTION else {}
    if name.lower() in MOLFORMER_VERSIONS:
        if version is None:
            raise ValueError(f"No MolFormer version given in {label!r}.")
        kwargs["algorithm_version"] = version
    return algo(config(**kwargs)), name


def main(property: str, smiles: str, smiles_file: str):
    """Predict a molecular property for one SMILES or a file of SMILES and render it.

    Args:
        property: Dropdown label: a capitalised predictor name (e.g. ``"Qed"``)
            or a MolFormer entry of the form ``"Molformer_<task> (<version>)"``.
        smiles: A single SMILES string; empty/None when a file is used instead.
        smiles_file: Uploaded tab-separated ``.smi`` file (SMILES in the first
            column), given as a path or a file object with ``.name``, or None.

    Returns:
        Whatever ``draw_grid_predict`` returns for the scored molecules.

    Raises:
        ValueError: If both or neither of ``smiles`` / ``smiles_file`` are given.
    """
    # Validate the inputs before building the model so that malformed requests
    # fail fast without instantiating a predictor.
    smiles_list = _collect_smiles(smiles, smiles_file)
    model, property = _load_predictor(property)

    props = np.array([model(s) for s in smiles_list]).round(2)
    # Single-output predictors give one value per molecule (1-D array); add a
    # trailing axis so there is always one column per property name.
    if props.ndim == 1:
        props = np.expand_dims(props, -1)

    if property in MODEL_PROP_DESCRIPTION:
        # Some descriptions use ", " as separator; strip so labels carry no
        # leading spaces.
        property_names = [
            name.strip() for name in MODEL_PROP_DESCRIPTION[property].split(",")
        ]
    else:
        property_names = [property]
    return draw_grid_predict(
        smiles_list, props, property_names=property_names, domain="Molecules"
    )
| if __name__ == "__main__": | |
| # Preparation (retrieve all available algorithms) | |
| properties = list(MOLECULE_PROPERTY_PREDICTOR_FACTORY.keys())[::-1] | |
| for prop in REMOVE: | |
| prop_to_idx = dict(zip(properties, range(len(properties)))) | |
| properties.pop(prop_to_idx[prop]) | |
| properties = list(map(lambda x: x.capitalize(), properties)) | |
| # MolFormer options | |
| for key in MOLFORMER_VERSIONS.keys(): | |
| properties.extend( | |
| [f"{key.capitalize()} ({version})" for version in MOLFORMER_VERSIONS[key]] | |
| ) | |
| # Load metadata | |
| metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards") | |
| examples = [ | |
| ["Qed", "", str(metadata_root.joinpath("examples.smi"))], | |
| [ | |
| "Esol", | |
| "CN1CCN(CCCOc2ccc(N3C(=O)C(=Cc4ccc(Oc5ccc([N+](=O)[O-])cc5)cc4)SC3=S)cc2)CC1", | |
| None, | |
| ], | |
| ] | |
| with open(metadata_root.joinpath("article.md"), "r") as f: | |
| article = f.read() | |
| with open(metadata_root.joinpath("description.md"), "r") as f: | |
| description = f.read() | |
| demo = gr.Interface( | |
| fn=main, | |
| title="Molecular properties", | |
| inputs=[ | |
| gr.Dropdown(properties, label="Property", value="Scscore"), | |
| gr.Textbox( | |
| label="Single SMILES", | |
| placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1", | |
| lines=1, | |
| ), | |
| gr.File( | |
| file_types=[".smi"], | |
| label="Multiple SMILES (tab-separated, `.smi` file)", | |
| ), | |
| ], | |
| outputs=gr.HTML(label="Output"), | |
| article=article, | |
| description=description, | |
| examples=examples, | |
| ) | |
| demo.launch(debug=True, show_error=True, share=True) | |