Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Drug Discovery Predictor - Deployment Ready for Hugging Face Spaces | |
| """ | |
| import os | |
| import gradio as gr | |
| import joblib | |
| import pandas as pd | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image, ImageDraw, ImageFont | |
| import requests | |
| from io import BytesIO | |
| from functools import partial | |
| import urllib.parse | |
# --- Global Setup ---
# On-disk cache for the compound structure images fetched at startup.
CACHE_DIR = "image_cache"
# exist_ok=True makes this a no-op when the directory is already present.
os.makedirs(name=CACHE_DIR, exist_ok=True)
print(f"Image cache directory created at: {CACHE_DIR}")
# --- 1. Define the PyTorch Model Class ---
# This class definition is required to load the saved PyTorch model state.
class MLPAgent(nn.Module):
    """Feed-forward classifier: Linear-ReLU-Dropout x2, then Linear + Softmax.

    The final Softmax is taken over dim=1, so inputs are expected as
    (batch, input_dim) and outputs are per-row class probabilities.

    NOTE: the layer order inside ``self.net`` must not change. The saved
    state dict is keyed by Sequential index (net.0, net.3, net.6), and
    reordering the layers would break ``load_state_dict``.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        hidden_widths = (128, 64)
        dropout_probs = (0.3, 0.2)

        layers = []
        fan_in = input_dim
        for width, p_drop in zip(hidden_widths, dropout_probs):
            layers += [nn.Linear(fan_in, width), nn.ReLU(), nn.Dropout(p_drop)]
            fan_in = width
        layers += [nn.Linear(fan_in, num_classes), nn.Softmax(dim=1)]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities for a batch ``x``."""
        return self.net(x)
# --- 2. Load all models and preprocessors ---
# This section runs once when the application starts. Any failure is logged
# and leaves MODELS_LOADED = False, so the UI can show an error banner
# instead of the app crashing at import time.
try:
    rf_model = joblib.load("models/rf_model.joblib")
    scaler = joblib.load("models/scaler.joblib")
    le = joblib.load("models/le.joblib")
    keras_model = keras.models.load_model("models/keras_mlp.h5")

    num_classes = len(le.classes_)
    # input_dim=5 matches the five molecular descriptors assembled in
    # master_predict (MW, LogP, HBA, HBD, TPSA).
    rl_agent = MLPAgent(input_dim=5, num_classes=num_classes)
    # map_location="cpu": the checkpoint may have been saved from CUDA
    # tensors. Without remapping, torch.load raises on CPU-only hardware.
    rl_agent.load_state_dict(
        torch.load("models/rl_upgraded_agent.pth", map_location="cpu")
    )
    rl_agent.eval()  # disable Dropout for inference

    print("All models and preprocessors loaded successfully.")
    MODELS_LOADED = True
except Exception as e:
    # Top-level startup boundary: report and fall back to the error UI.
    print(f"Error loading models: {e}")
    MODELS_LOADED = False
# Mapping from protein targets to known chemical compounds for image search.
# The compound name is what gets looked up (and cached) as a 2D structure
# image. Targets not listed here fall back to their own name at lookup time.
PROTEIN_TO_COMPOUND_MAP = {
    "BACE1": "Verubecestat",
    "HDAC1": "Vorinostat",
    "EGFR": "Gefitinib",
    "DRD2": "Haloperidol",
    "HIV-1 RT": "Nevirapine",
    "AMPC": "Cefoxitin",
    "MMP-13": "Marimastat",
}
# --- 3. Image Pre-Fetching Logic ---
# To ensure fast performance, we download all images when the app starts.
def pre_fetch_images():
    """Populate CACHE_DIR with a structure image for each mapped compound.

    For every distinct compound in PROTEIN_TO_COMPOUND_MAP, the image is
    requested from the NCI CACTUS resolver and saved as ``<alnum-name>.png``.
    Compounds whose file already exists are skipped. This is best effort:
    any download or decode failure is printed, and the loop moves on.
    """
    print("\n--- Starting Image Pre-Fetching ---")
    # Browser-like User-Agent for the outgoing request.
    request_headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    for name in set(PROTEIN_TO_COMPOUND_MAP.values()):
        # Keep only alphanumerics so the name is a safe filename.
        file_stem = "".join(filter(str.isalnum, name))
        target_path = os.path.join(CACHE_DIR, file_stem + ".png")

        if os.path.exists(target_path):
            print(f"Image for '{name}' already in cache. Skipping.")
            continue

        try:
            print(f"Downloading image for '{name}'...")
            encoded_name = urllib.parse.quote(name)
            resolver_url = f"https://cactus.nci.nih.gov/chemical/structure/{encoded_name}/image"
            resp = requests.get(resolver_url, timeout=20, headers=request_headers)

            got_image = resp.status_code == 200 and 'image' in resp.headers.get('Content-Type', '')
            if not got_image:
                print(f" ❌ Failed for '{name}': Server returned non-image content.")
                continue

            # Re-encode to PNG based on the target file's extension.
            Image.open(BytesIO(resp.content)).save(target_path)
            print(f" ✅ Success: Saved '{name}' to cache.")
        except Exception as err:
            print(f" ❌ Failed for '{name}': {err}")

    print("--- Image Pre-Fetching Complete ---\n")


# Run the pre-fetching function immediately (only when the models loaded,
# since the images are only shown alongside predictions).
if MODELS_LOADED:
    pre_fetch_images()
# --- 4. Helper and Prediction Functions ---
def create_error_image(message):
    """Return a 400x300 white RGB image with *message* drawn in red at (10, 10).

    DejaVuSans at size 15 is used when it can be loaded; otherwise Pillow's
    built-in default font is used.
    """
    canvas = Image.new('RGB', (400, 300), color=(255, 255, 255))

    try:  # Use a common font
        label_font = ImageFont.truetype("DejaVuSans.ttf", 15)
    except IOError:
        label_font = ImageFont.load_default()

    pen = ImageDraw.Draw(canvas)
    pen.text((10, 10), message, fill=(200, 0, 0), font=label_font)
    return canvas
def get_compound_structure_image_from_cache(compound_name):
    """Return the cached structure image for *compound_name*.

    The cache filename is the compound name with all non-alphanumeric
    characters removed, plus ``.png``, inside CACHE_DIR. This is the same
    naming scheme pre_fetch_images uses.

    Returns:
        A fully decoded PIL image when the cache file exists and is readable.
        Otherwise, an error image from create_error_image. This covers a
        missing file and a corrupt or truncated file alike, so the prediction
        callback never raises because of the image.
    """
    sanitized_name = "".join(c for c in compound_name if c.isalnum())
    local_image_path = os.path.join(CACHE_DIR, f"{sanitized_name}.png")
    try:
        # Image.open is lazy. Decode now via copy() inside `with`, so the file
        # handle is closed immediately and decode errors surface here,
        # where they can be handled.
        with Image.open(local_image_path) as cached:
            return cached.copy()
    except FileNotFoundError:
        return create_error_image(f"Image for '{compound_name}'\nwas not found in local cache.")
    except OSError as err:
        # Pillow's UnidentifiedImageError and truncated-file errors are OSErrors.
        print(f"Could not read cached image '{local_image_path}': {err}")
        return create_error_image(f"Image for '{compound_name}'\ncould not be read from local cache.")
def master_predict(model_choice, mol_weight, logp, hba, hbd, tpsa, rf_m, sc, l_enc, keras_m, rl_a):
    """Predict protein targets from five molecular descriptors.

    Args:
        model_choice: "Normal MLP" selects the Keras model ``keras_m``. Any
            other value selects the PyTorch agent ``rl_a``.
        mol_weight, logp, hba, hbd, tpsa: descriptor values from the UI.
        rf_m: random-forest classifier. Its predictions are decoded with
            ``l_enc``, so it presumably predicts label-encoded indices.
        sc: fitted scaler applied to the one-row feature frame.
        l_enc: label encoder mapping class indices back to protein names.
        keras_m, rl_a: MLP models returning per-class probabilities.

    Returns:
        A tuple ``(rf_label, top3_df, structure_image)``:
        - ``rf_label``: the random-forest protein name.
        - ``top3_df``: a DataFrame with columns Protein / Probability /
          Probability %, ordered from most to least likely.
        - ``structure_image``: the cached image of a compound associated
          with the top MLP protein.
    """
    feature_row = {'Molecular Weight': mol_weight, 'LogP': logp, 'HBA': hba, 'HBD': hbd, 'TPSA': tpsa}
    # Explicit column order so the scaler sees features as during fitting.
    column_order = ['Molecular Weight', 'LogP', 'HBA', 'HBD', 'TPSA']
    scaled = sc.transform(pd.DataFrame([feature_row], columns=column_order))

    # Random-forest prediction (single most likely class).
    rf_index = rf_m.predict(scaled)[0]
    rf_label = l_enc.inverse_transform([rf_index])[0]

    # Probability vector from the selected MLP.
    if model_choice != "Normal MLP":
        with torch.no_grad():
            probs = rl_a(torch.tensor(scaled, dtype=torch.float32)).numpy()[0]
    else:
        probs = keras_m.predict(scaled, verbose=0)[0]

    def _as_row(class_idx):
        """Build one [protein, probability, formatted %] table row."""
        p = float(probs[class_idx])
        return [l_enc.inverse_transform([class_idx])[0], p, f"{p:.2%}"]

    # Indices of the three largest probabilities, highest first.
    ranked = np.argsort(probs)[-3:][::-1]
    top3_df = pd.DataFrame(
        [_as_row(class_idx) for class_idx in ranked],
        columns=["Protein", "Probability", "Probability %"],
    )

    # Unmapped proteins fall back to searching by the protein name itself.
    best_protein = top3_df["Protein"].iloc[0]
    search_name = PROTEIN_TO_COMPOUND_MAP.get(best_protein, best_protein)
    return rf_label, top3_df, get_compound_structure_image_from_cache(search_name)
# --- 5. Build the Gradio Interface ---
# Components appear in the order they are created inside the `with` blocks
# below, so statement order here defines the page layout.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Drug Discovery: Protein Target Predictor")
    if not MODELS_LOADED:
        # Startup loading failed: render only an error banner, no inputs.
        gr.Markdown("## ERROR: MODELS FAILED TO LOAD. PLEASE CHECK THE REPOSITORY AND LOGS.")
    else:
        # Pre-bind the loaded models/preprocessors. The click handler then
        # only supplies the six UI values (model choice + five descriptors).
        predict_with_models = partial(master_predict, rf_m=rf_model, sc=scaler, l_enc=le, keras_m=keras_model, rl_a=rl_agent)
        with gr.Row():
            # Left column: model selection and molecular-property inputs.
            with gr.Column(scale=1):
                gr.Markdown("### 1. Choose Your Model")
                # The value "Normal MLP" selects the Keras model in
                # master_predict; anything else selects the RL agent.
                model_choice = gr.Radio(["Normal MLP", "RL Upgraded MLP"], label="Select an MLP Model", value="Normal MLP")
                gr.Markdown("### 2. Input Molecular Properties")
                mw_slider = gr.Slider(100, 1000, value=350, step=1, label="Molecular Weight (g/mol)")
                logp_slider = gr.Slider(-5, 10, value=2.5, step=0.1, label="LogP (Lipophilicity)")
                hba_slider = gr.Slider(0, 20, value=4, step=1, label="HBA (Hydrogen Bond Acceptors)")
                hbd_slider = gr.Slider(0, 20, value=2, step=1, label="HBD (Hydrogen Bond Donors)")
                tpsa_slider = gr.Slider(0, 300, value=60, step=1, label="TPSA (Topological Polar Surface Area Ų)")
                submit_btn = gr.Button("Predict Target Protein", variant="primary")
            # Right column: the three outputs returned by master_predict.
            with gr.Column(scale=2):
                gr.Markdown("### 3. Prediction Results")
                out_rf = gr.Textbox(label="Random Forest Prediction (Most Likely Target)")
                # Headers/datatypes mirror the DataFrame columns built in
                # master_predict (Protein, Probability, Probability %).
                out_mlp = gr.DataFrame(headers=["Protein", "Probability", "Probability %"], label="Top 3 MLP Predictions", datatype=["str", "number", "str"])
                out_image = gr.Image(label="2D Structure of an Associated Compound", type="pil")
        # Inputs are passed positionally in master_predict's parameter order.
        # Outputs map to its (rf_label, top3_df, image) return tuple.
        submit_btn.click(
            fn=predict_with_models,
            inputs=[model_choice, mw_slider, logp_slider, hba_slider, hbd_slider, tpsa_slider],
            outputs=[out_rf, out_mlp, out_image]
        )
# Launch the application
iface.launch()