# Source: Hugging Face Space by ACA050, commit 45fff2c ("Create app.py", verified).
# -*- coding: utf-8 -*-
"""
Drug Discovery Predictor - Deployment Ready for Hugging Face Spaces
"""
import os
import gradio as gr
import joblib
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
from functools import partial
import urllib.parse
# --- Global Setup ---
# Structure images are downloaded into this local directory at startup so that
# prediction requests only ever read from disk.
CACHE_DIR = "image_cache"
os.makedirs(name=CACHE_DIR, exist_ok=True)  # no-op when the directory already exists
print("Image cache directory created at: " + CACHE_DIR)
# --- 1. Define the PyTorch Model Class ---
# This class definition is required to load the saved PyTorch model state.
class MLPAgent(nn.Module):
    """Three-layer MLP classifier whose output is a softmax over classes.

    ``load_state_dict`` matches parameters by key (``net.0.*``, ``net.3.*``,
    ``net.6.*``), so the layer sequence below must stay exactly as it is for
    the saved checkpoint to load. Do not insert, remove or reorder layers.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output classes (width of the softmax).
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
            nn.Softmax(dim=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities for a batch ``x``.

        Softmax is applied over dim=1, so ``x`` is expected to be 2-D,
        shaped (batch, input_dim).
        """
        probabilities = self.net(x)
        return probabilities
# --- 2. Load all models and preprocessors ---
# This section runs once when the application starts. On success it defines
# rf_model, scaler, le, keras_model, num_classes and rl_agent and sets
# MODELS_LOADED = True; on any failure it sets MODELS_LOADED = False and the UI
# shows an error banner instead of the prediction form.
try:
    rf_model = joblib.load("models/rf_model.joblib")
    scaler = joblib.load("models/scaler.joblib")
    le = joblib.load("models/le.joblib")
    keras_model = keras.models.load_model("models/keras_mlp.h5")
    num_classes = len(le.classes_)
    # input_dim=5 must match the five descriptors master_predict feeds the
    # scaler (Molecular Weight, LogP, HBA, HBD, TPSA).
    rl_agent = MLPAgent(input_dim=5, num_classes=num_classes)
    # map_location="cpu": a checkpoint saved from a GPU session would otherwise
    # fail to deserialize on a CPU-only host, disabling the whole app.
    rl_agent.load_state_dict(torch.load("models/rl_upgraded_agent.pth", map_location="cpu"))
    rl_agent.eval()  # inference mode: disables the Dropout layers
    print("All models and preprocessors loaded successfully.")
    MODELS_LOADED = True
except Exception as e:
    # Deliberately broad: any loading failure degrades to the error banner
    # rather than crashing the app at import time.
    print(f"Error loading models: {e}")
    MODELS_LOADED = False
# Mapping from protein targets to known chemical compounds for image search.
# master_predict looks up its top-ranked target here to decide which cached
# structure image to display; targets missing from this map fall back to the
# target name itself.
PROTEIN_TO_COMPOUND_MAP = {
    "BACE1": "Verubecestat",
    "HDAC1": "Vorinostat",
    "EGFR": "Gefitinib",
    "DRD2": "Haloperidol",
    "HIV-1 RT": "Nevirapine",
    "AMPC": "Cefoxitin",
    "MMP-13": "Marimastat",
}
# --- 3. Image Pre-Fetching Logic ---
# To ensure fast performance, we download all images when the app starts.
def pre_fetch_images(compound_names=None, timeout=20):
    """Download 2D structure images for the given compounds into CACHE_DIR.

    Each image is requested from
    ``https://cactus.nci.nih.gov/chemical/structure/<name>/image`` and saved
    as ``CACHE_DIR/<alphanumeric characters of name>.png``. Files already in
    the cache are skipped. This is best effort: every failure is printed and
    the loop continues, so a missing image never stops the app from starting.

    Args:
        compound_names: Iterable of compound names to fetch. Defaults to every
            compound in PROTEIN_TO_COMPOUND_MAP.
        timeout: Per-request timeout in seconds.
    """
    if compound_names is None:
        compound_names = PROTEIN_TO_COMPOUND_MAP.values()
    print("\n--- Starting Image Pre-Fetching ---")
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    # set() drops duplicate compounds; sorted() gives a stable download/log order.
    for compound_name in sorted(set(compound_names)):
        sanitized_name = "".join(c for c in compound_name if c.isalnum())
        local_image_path = os.path.join(CACHE_DIR, f"{sanitized_name}.png")
        if os.path.exists(local_image_path):
            print(f"Image for '{compound_name}' already in cache. Skipping.")
            continue
        # Save to a temporary file and rename it into place only on success.
        # Otherwise an interrupted or failed save would leave a truncated PNG
        # that every later run would treat as a valid cache hit.
        tmp_path = f"{local_image_path}.part"
        try:
            print(f"Downloading image for '{compound_name}'...")
            # safe="" also escapes "/" so arbitrary names cannot alter the URL path.
            url_safe_name = urllib.parse.quote(compound_name, safe="")
            image_url = f"https://cactus.nci.nih.gov/chemical/structure/{url_safe_name}/image"
            response = requests.get(image_url, timeout=timeout, headers=headers)
            content_type = response.headers.get('Content-Type', '')
            if response.status_code == 200 and 'image' in content_type:
                with Image.open(BytesIO(response.content)) as image:
                    image.save(tmp_path, format="PNG")
                os.replace(tmp_path, local_image_path)
                print(f" ✅ Success: Saved '{compound_name}' to cache.")
            else:
                print(f" ❌ Failed for '{compound_name}': HTTP {response.status_code}, "
                      f"Content-Type '{content_type}' (expected an image).")
        except Exception as e:
            # Best effort: log and move on to the next compound.
            print(f" ❌ Failed for '{compound_name}': {e}")
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    print("--- Image Pre-Fetching Complete ---\n")


# Run the pre-fetching function immediately, but only when the models loaded.
# Otherwise the UI shows just an error banner and no structure images are needed.
if MODELS_LOADED:
    pre_fetch_images()
# --- 4. Helper and Prediction Functions ---
def create_error_image(message):
    """Return a 400x300 white RGB placeholder with *message* drawn in red.

    The text is drawn at offset (10, 10) in DejaVuSans size 15 when that font
    can be loaded, otherwise in Pillow's default font.
    """
    canvas = Image.new('RGB', (400, 300), color=(255, 255, 255))
    try:
        # Use a common font
        label_font = ImageFont.truetype("DejaVuSans.ttf", 15)
    except OSError:  # IOError is an alias of OSError
        label_font = ImageFont.load_default()
    ImageDraw.Draw(canvas).text((10, 10), message, fill=(200, 0, 0), font=label_font)
    return canvas
def get_compound_structure_image_from_cache(compound_name):
    """Return the cached 2D structure image for *compound_name* as a PIL image.

    Reads ``CACHE_DIR/<alphanumeric characters of name>.png``, the same naming
    scheme pre_fetch_images uses when saving. If the file is missing or cannot
    be decoded, a placeholder from create_error_image is returned instead. The
    caller therefore always gets an image and never an exception.
    """
    sanitized_name = "".join(c for c in compound_name if c.isalnum())
    local_image_path = os.path.join(CACHE_DIR, f"{sanitized_name}.png")
    if not os.path.exists(local_image_path):
        return create_error_image(f"Image for '{compound_name}'\nwas not found in local cache.")
    try:
        # Image.open is lazy and keeps the file open until the pixels are read.
        # copy() inside the with-block forces the load and lets the handle close
        # immediately, so there is no descriptor leak per request.
        with Image.open(local_image_path) as cached_image:
            return cached_image.copy()
    except OSError:
        # Corrupt or truncated cache file (PIL's UnidentifiedImageError is an
        # OSError subclass). Show a placeholder instead of failing the request.
        return create_error_image(f"Image for '{compound_name}'\nin local cache is unreadable.")
def master_predict(model_choice, mol_weight, logp, hba, hbd, tpsa, rf_m, sc, l_enc, keras_m, rl_a):
    """Predict protein targets for one molecule from five descriptors.

    Args:
        model_choice: "Normal MLP" selects the Keras model ``keras_m``; any
            other value selects the PyTorch model ``rl_a``.
        mol_weight, logp, hba, hbd, tpsa: Descriptor values. They are placed in
            the columns 'Molecular Weight', 'LogP', 'HBA', 'HBD', 'TPSA'.
        rf_m: Random forest whose ``predict`` output is decoded with ``l_enc``.
        sc: Fitted scaler applied to the descriptor row.
        l_enc: Label encoder mapping class indices back to protein names.
        keras_m: Keras MLP returning per-class probabilities.
        rl_a: PyTorch model (MLPAgent) returning per-class probabilities.

    Returns:
        A tuple of three items:
        - The random forest's decoded prediction.
        - A DataFrame with columns "Protein", "Probability" and
          "Probability %" for the three most probable classes, highest first.
        - A structure image for the top class's mapped compound.
    """
    feature_row = {'Molecular Weight': mol_weight, 'LogP': logp, 'HBA': hba, 'HBD': hbd, 'TPSA': tpsa}
    scaled = sc.transform(pd.DataFrame([feature_row], columns=list(feature_row)))

    # Random forest: predicted class index -> protein label.
    rf_index = rf_m.predict(scaled)[0]
    rf_label = l_enc.inverse_transform([rf_index])[0]

    # Probability vector for the single input row from the selected MLP.
    if model_choice != "Normal MLP":
        with torch.no_grad():
            probs = rl_a(torch.tensor(scaled, dtype=torch.float32)).numpy()[0]
    else:
        probs = keras_m.predict(scaled, verbose=0)[0]

    # Indices of the three largest probabilities, highest first.
    top_indices = np.argsort(probs)[::-1][:3]
    rows = [
        [l_enc.inverse_transform([idx])[0], float(probs[idx]), f"{float(probs[idx]):.2%}"]
        for idx in top_indices
    ]
    mlp_results_df = pd.DataFrame(rows, columns=["Protein", "Probability", "Probability %"])

    # Show a structure for the top-ranked target's representative compound,
    # falling back to the target name when it has no mapping.
    best_protein = mlp_results_df.iloc[0]["Protein"]
    compound = PROTEIN_TO_COMPOUND_MAP.get(best_protein, best_protein)
    return rf_label, mlp_results_df, get_compound_structure_image_from_cache(compound)
# --- 5. Build the Gradio Interface ---
# If model loading failed, the page shows only an error banner. Otherwise the
# input/output layout is built and wired to master_predict. Components appear
# on the page in the order they are created below.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Drug Discovery: Protein Target Predictor")
    if not MODELS_LOADED:
        gr.Markdown("## ERROR: MODELS FAILED TO LOAD. PLEASE CHECK THE REPOSITORY AND LOGS.")
    else:
        # Bind the loaded models so the click handler receives only the six UI inputs.
        predict_with_models = partial(master_predict, rf_m=rf_model, sc=scaler, l_enc=le, keras_m=keras_model, rl_a=rl_agent)
        with gr.Row():
            # Left column: model choice, descriptor inputs and the submit button.
            with gr.Column(scale=1):
                gr.Markdown("### 1. Choose Your Model")
                model_choice = gr.Radio(["Normal MLP", "RL Upgraded MLP"], label="Select an MLP Model", value="Normal MLP")
                gr.Markdown("### 2. Input Molecular Properties")
                mw_slider = gr.Slider(100, 1000, value=350, step=1, label="Molecular Weight (g/mol)")
                logp_slider = gr.Slider(-5, 10, value=2.5, step=0.1, label="LogP (Lipophilicity)")
                hba_slider = gr.Slider(0, 20, value=4, step=1, label="HBA (Hydrogen Bond Acceptors)")
                hbd_slider = gr.Slider(0, 20, value=2, step=1, label="HBD (Hydrogen Bond Donors)")
                tpsa_slider = gr.Slider(0, 300, value=60, step=1, label="TPSA (Topological Polar Surface Area Ų)")
                submit_btn = gr.Button("Predict Target Protein", variant="primary")
            # Right column: result components, in the order master_predict returns them.
            with gr.Column(scale=2):
                gr.Markdown("### 3. Prediction Results")
                out_rf = gr.Textbox(label="Random Forest Prediction (Most Likely Target)")
                out_mlp = gr.DataFrame(headers=["Protein", "Probability", "Probability %"], label="Top 3 MLP Predictions", datatype=["str", "number", "str"])
                out_image = gr.Image(label="2D Structure of an Associated Compound", type="pil")
        # The inputs list must match master_predict's leading positional parameters
        # (model_choice, mol_weight, logp, hba, hbd, tpsa). The outputs receive the
        # returned (rf prediction, top-3 DataFrame, image) tuple in that order.
        submit_btn.click(
            fn=predict_with_models,
            inputs=[model_choice, mw_slider, logp_slider, hba_slider, hbd_slider, tpsa_slider],
            outputs=[out_rf, out_mlp, out_image]
        )
# Launch the application
# NOTE(review): consider `if __name__ == "__main__":` so that importing this module
# (e.g. from tests) does not start a server. First confirm how the hosting
# runtime starts the app.
iface.launch()