# -*- coding: utf-8 -*-
"""
Drug Discovery Predictor - Deployment Ready for Hugging Face Spaces
"""
import os
import gradio as gr
import joblib
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
from functools import partial
import urllib.parse

# --- Global Setup ---
# Caching Setup for downloaded images
CACHE_DIR = "image_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
print(f"Image cache directory created at: {CACHE_DIR}")


# --- 1. Define the PyTorch Model Class ---
# This class definition is required to load the saved PyTorch model state.
class MLPAgent(nn.Module):
    """Feed-forward classifier: input_dim -> 128 -> 64 -> num_classes.

    Hidden layers use ReLU followed by dropout (0.3, then 0.2); the output
    layer is followed by a softmax over dimension 1, so ``forward`` returns
    per-class probabilities for each row of its input.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # NOTE: keep the attribute name ``net`` and the layer order unchanged.
        # The saved state dict is loaded into this module below, and its keys
        # have to match the parameter names this structure produces.
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return the softmax output of the network for the batch ``x``."""
        return self.net(x)


# --- 2. Load all models and preprocessors ---
# This section runs once when the application starts.
# Any failure is reported and recorded in MODELS_LOADED instead of crashing,
# so the UI can still start and display an error banner.
try:
    rf_model = joblib.load("models/rf_model.joblib")
    scaler = joblib.load("models/scaler.joblib")
    le = joblib.load("models/le.joblib")
    keras_model = keras.models.load_model("models/keras_mlp.h5")
    num_classes = len(le.classes_)
    rl_agent = MLPAgent(input_dim=5, num_classes=num_classes)
    # map_location forces the tensors onto the CPU. Without it, a checkpoint
    # that was saved from GPU tensors cannot be deserialized on a CPU-only
    # host, and inference below converts the output with .numpy(), which
    # requires CPU tensors anyway.
    rl_agent.load_state_dict(
        torch.load("models/rl_upgraded_agent.pth", map_location=torch.device("cpu"))
    )
    # Inference only: eval() disables the dropout layers.
    rl_agent.eval()
    print("All models and preprocessors loaded successfully.")
    MODELS_LOADED = True
except Exception as e:
    print(f"Error loading models: {e}")
    MODELS_LOADED = False

# Mapping from protein targets to known chemical compounds for image search
PROTEIN_TO_COMPOUND_MAP = {
    "BACE1": "Verubecestat",
    "HDAC1": "Vorinostat",
    "EGFR": "Gefitinib",
    "DRD2": "Haloperidol",
    "HIV-1 RT": "Nevirapine",
    "AMPC": "Cefoxitin",
    "MMP-13": "Marimastat",
}
# --- 3. Image Pre-Fetching Logic ---
# To ensure fast performance, we download all images when the app starts.
def pre_fetch_images():
    """Download a structure image for every mapped compound not yet cached.

    Each distinct compound name in PROTEIN_TO_COMPOUND_MAP is stored under
    CACHE_DIR as ``<alphanumeric characters of the name>.png``. Compounds
    whose file already exists are skipped. Download or decoding failures are
    printed and do not stop the remaining downloads.
    """
    print("\n--- Starting Image Pre-Fetching ---")
    request_headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    unique_compounds = set(PROTEIN_TO_COMPOUND_MAP.values())

    for compound_name in unique_compounds:
        # Only letters and digits are kept so the name is safe as a file name.
        file_stem = "".join(filter(str.isalnum, compound_name))
        cached_path = os.path.join(CACHE_DIR, f"{file_stem}.png")

        if os.path.exists(cached_path):
            print(f"Image for '{compound_name}' already in cache. Skipping.")
            continue

        try:
            print(f"Downloading image for '{compound_name}'...")
            quoted_name = urllib.parse.quote(compound_name)
            image_url = f"https://cactus.nci.nih.gov/chemical/structure/{quoted_name}/image"
            response = requests.get(image_url, timeout=20, headers=request_headers)

            got_image = (
                response.status_code == 200
                and 'image' in response.headers.get('Content-Type', '')
            )
            if not got_image:
                print(f" ❌ Failed for '{compound_name}': Server returned non-image content.")
                continue

            downloaded = Image.open(BytesIO(response.content))
            downloaded.save(cached_path)
            print(f" ✅ Success: Saved '{compound_name}' to cache.")
        except Exception as e:
            # Best effort: report the problem and move on to the next compound.
            print(f" ❌ Failed for '{compound_name}': {e}")

    print("--- Image Pre-Fetching Complete ---\n")


# Run the pre-fetching function immediately
if MODELS_LOADED:
    pre_fetch_images()
# --- 4. Helper and Prediction Functions ---
def create_error_image(message):
    """Return a 400x300 white image with ``message`` drawn in red.

    DejaVuSans at size 15 is used when that font file can be found;
    otherwise PIL's built-in default font is used.
    """
    img = Image.new('RGB', (400, 300), color=(255, 255, 255))
    d = ImageDraw.Draw(img)
    try:
        # Use a common font
        font = ImageFont.truetype("DejaVuSans.ttf", 15)
    except IOError:
        font = ImageFont.load_default()
    d.text((10, 10), message, fill=(200, 0, 0), font=font)
    return img


def get_compound_structure_image_from_cache(compound_name):
    """Return the cached structure image for ``compound_name``.

    The file is looked up in CACHE_DIR under the alphanumeric characters of
    the name plus ``.png``. When the file is missing, or exists but cannot be
    decoded, a placeholder image carrying an explanatory message is returned
    instead, so this function always yields an image.
    """
    sanitized_name = "".join(c for c in compound_name if c.isalnum())
    local_image_path = os.path.join(CACHE_DIR, f"{sanitized_name}.png")

    if not os.path.exists(local_image_path):
        return create_error_image(f"Image for '{compound_name}'\nwas not found in local cache.")

    try:
        # Image.open is lazy and keeps the file open. Copying inside the
        # ``with`` block forces the pixel data to be read and returns an
        # image that no longer depends on the file handle, which is then
        # closed on exit.
        with Image.open(local_image_path) as cached_image:
            return cached_image.copy()
    except (OSError, SyntaxError):
        # A truncated or corrupt cache file must not break the prediction.
        # OSError covers unreadable/unidentified files; SyntaxError is
        # included because the PNG reader may raise it for damaged chunks.
        return create_error_image(f"Image for '{compound_name}'\ncould not be read from local cache.")


def master_predict(model_choice, mol_weight, logp, hba, hbd, tpsa, rf_m, sc, l_enc, keras_m, rl_a):
    """Run the Random Forest and the selected MLP on one set of descriptors.

    Args:
        model_choice: "Normal MLP" selects ``keras_m``; any other value
            selects ``rl_a``.
        mol_weight, logp, hba, hbd, tpsa: The five molecular descriptors,
            passed to the scaler in exactly this column order.
        rf_m: Classifier exposing ``predict``; returns encoded class indices.
        sc: Fitted scaler exposing ``transform``.
        l_enc: Label encoder exposing ``inverse_transform``.
        keras_m: Model exposing ``predict(x, verbose=0)`` that returns
            per-class scores.
        rl_a: Torch module that returns per-class scores for a float32 tensor.

    Returns:
        A tuple ``(rf_class, top3_df, structure_image)``: the decoded Random
        Forest prediction, a DataFrame of up to three highest-scoring classes
        (columns "Protein", "Probability", "Probability %") in descending
        order, and the cached structure image for the top-ranked protein.
    """
    features = ['Molecular Weight', 'LogP', 'HBA', 'HBD', 'TPSA']
    input_df = pd.DataFrame(
        [{'Molecular Weight': mol_weight, 'LogP': logp, 'HBA': hba, 'HBD': hbd, 'TPSA': tpsa}],
        columns=features,
    )
    s_scaled = sc.transform(input_df)

    # The Random Forest result is always reported, whichever MLP is selected.
    pred_rf_idx = rf_m.predict(s_scaled)[0]
    pred_rf_class = l_enc.inverse_transform([pred_rf_idx])[0]

    if model_choice == "Normal MLP":
        pred_prob = keras_m.predict(s_scaled, verbose=0)[0]
    else:
        with torch.no_grad():
            s_torch = torch.tensor(s_scaled, dtype=torch.float32)
            pred_prob = rl_a(s_torch).numpy()[0]

    # argsort is ascending: take the last three entries and reverse them so
    # the most probable class comes first.
    top3_idx = np.argsort(pred_prob)[-3:][::-1]
    top3_predictions_data = []
    for i in top3_idx:
        protein_name = l_enc.inverse_transform([i])[0]
        probability = float(pred_prob[i])
        prob_percent = f"{probability:.2%}"
        top3_predictions_data.append([protein_name, probability, prob_percent])
    mlp_results_df = pd.DataFrame(
        top3_predictions_data,
        columns=["Protein", "Probability", "Probability %"],
    )

    # Proteins without a mapped compound fall back to their own name as the
    # cache lookup key.
    top_protein_name = mlp_results_df.iloc[0]["Protein"]
    compound_to_search = PROTEIN_TO_COMPOUND_MAP.get(top_protein_name, top_protein_name)
    structure_image = get_compound_structure_image_from_cache(compound_to_search)
    return pred_rf_class, mlp_results_df, structure_image
# --- 5. Build the Gradio Interface ---
# Two-column layout: model choice and descriptor sliders on the left,
# prediction outputs on the right. When the models could not be loaded,
# only an error banner is rendered.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Drug Discovery: Protein Target Predictor")

    if MODELS_LOADED:
        # Bind the loaded models once so the click handler only receives
        # the values coming from the UI widgets.
        predict_with_models = partial(
            master_predict,
            rf_m=rf_model,
            sc=scaler,
            l_enc=le,
            keras_m=keras_model,
            rl_a=rl_agent,
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 1. Choose Your Model")
                model_choice = gr.Radio(
                    ["Normal MLP", "RL Upgraded MLP"],
                    label="Select an MLP Model",
                    value="Normal MLP",
                )

                gr.Markdown("### 2. Input Molecular Properties")
                mw_slider = gr.Slider(
                    100, 1000, value=350, step=1,
                    label="Molecular Weight (g/mol)",
                )
                logp_slider = gr.Slider(
                    -5, 10, value=2.5, step=0.1,
                    label="LogP (Lipophilicity)",
                )
                hba_slider = gr.Slider(
                    0, 20, value=4, step=1,
                    label="HBA (Hydrogen Bond Acceptors)",
                )
                hbd_slider = gr.Slider(
                    0, 20, value=2, step=1,
                    label="HBD (Hydrogen Bond Donors)",
                )
                tpsa_slider = gr.Slider(
                    0, 300, value=60, step=1,
                    label="TPSA (Topological Polar Surface Area Ų)",
                )
                submit_btn = gr.Button("Predict Target Protein", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### 3. Prediction Results")
                out_rf = gr.Textbox(label="Random Forest Prediction (Most Likely Target)")
                out_mlp = gr.DataFrame(
                    headers=["Protein", "Probability", "Probability %"],
                    label="Top 3 MLP Predictions",
                    datatype=["str", "number", "str"],
                )
                out_image = gr.Image(
                    label="2D Structure of an Associated Compound",
                    type="pil",
                )

        # Input order matches the leading positional parameters of
        # master_predict; output order matches its returned tuple.
        prediction_inputs = [
            model_choice,
            mw_slider,
            logp_slider,
            hba_slider,
            hbd_slider,
            tpsa_slider,
        ]
        prediction_outputs = [out_rf, out_mlp, out_image]
        submit_btn.click(
            fn=predict_with_models,
            inputs=prediction_inputs,
            outputs=prediction_outputs,
        )
    else:
        gr.Markdown("## ERROR: MODELS FAILED TO LOAD. PLEASE CHECK THE REPOSITORY AND LOGS.")

# Launch the application
iface.launch()