ACA050 committed on
Commit
45fff2c
·
verified ·
1 Parent(s): 8293825

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Drug Discovery Predictor - Deployment Ready for Hugging Face Spaces
4
+ """
5
+
6
+ import os
7
+ import gradio as gr
8
+ import joblib
9
+ import pandas as pd
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ from tensorflow import keras
13
+ import torch
14
+ import torch.nn as nn
15
+ from PIL import Image, ImageDraw, ImageFont
16
+ import requests
17
+ from io import BytesIO
18
+ from functools import partial
19
+ import urllib.parse
20
+
21
+ # --- Global Setup ---
22
+ # Caching Setup for downloaded images
23
# Folder (relative to the working directory) that holds downloaded compound
# structure images so each one only has to be fetched once.
CACHE_DIR = "image_cache"

# exist_ok=True makes this a no-op when the folder is already present.
os.makedirs(CACHE_DIR, exist_ok=True)
print(f"Image cache directory created at: {CACHE_DIR}")
26
+
27
+ # --- 1. Define the PyTorch Model Class ---
28
+ # This class definition is required to load the saved PyTorch model state.
29
class MLPAgent(nn.Module):
    """Feed-forward classifier that turns a feature batch into class probabilities.

    The network is a single ``nn.Sequential`` stored as ``self.net``::

        Linear(input_dim, 128) -> ReLU -> Dropout(0.3)
        -> Linear(128, 64) -> ReLU -> Dropout(0.2)
        -> Linear(64, num_classes) -> Softmax(dim=1)

    The attribute name ``net`` and the position of each layer determine the
    ``state_dict`` keys (``net.0.*``, ``net.3.*``, ``net.6.*``), so both must
    stay as they are for previously saved weights to load.
    """

    def __init__(self, input_dim, num_classes):
        """Build the layer stack for ``input_dim`` inputs and ``num_classes`` outputs."""
        super().__init__()
        first_hidden = [nn.Linear(input_dim, 128), nn.ReLU(), nn.Dropout(0.3)]
        second_hidden = [nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.2)]
        # Softmax over dim=1 means the input is expected to be 2-D
        # (rows = samples); each output row sums to 1.
        output_head = [nn.Linear(64, num_classes), nn.Softmax(dim=1)]
        self.net = nn.Sequential(*first_hidden, *second_hidden, *output_head)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the softmax output."""
        return self.net(x)
39
+
40
+ # --- 2. Load all models and preprocessors ---
41
+ # This section runs once when the application starts.
42
# Load every model and preprocessor once at start-up. MODELS_LOADED records
# whether all of them loaded; later sections check it before touching the
# model objects, which are left undefined if anything here fails.
try:
    rf_model = joblib.load("models/rf_model.joblib")
    scaler = joblib.load("models/scaler.joblib")
    le = joblib.load("models/le.joblib")
    keras_model = keras.models.load_model("models/keras_mlp.h5")

    # The PyTorch agent needs one output per class known to the label encoder.
    num_classes = len(le.classes_)
    rl_agent = MLPAgent(input_dim=5, num_classes=num_classes)
    # map_location="cpu" remaps tensors that were saved from a GPU so the
    # checkpoint also loads on a host without CUDA; without it such a
    # checkpoint fails to deserialize and the app falls into the error state.
    rl_agent.load_state_dict(
        torch.load("models/rl_upgraded_agent.pth", map_location="cpu")
    )
    # eval() switches the Dropout layers off for inference.
    rl_agent.eval()
    print("All models and preprocessors loaded successfully.")
    MODELS_LOADED = True
except Exception as e:
    # Deliberately broad: this is the start-up boundary, and any failure
    # (missing file, version mismatch, ...) should surface in the UI banner
    # rather than crash the process.
    print(f"Error loading models: {e}")
    MODELS_LOADED = False
56
+
57
+ # Mapping from protein targets to known chemical compounds for image search
58
# Protein target name -> name of a compound whose structure image is shown
# when that target is the top prediction.
PROTEIN_TO_COMPOUND_MAP = {
    "BACE1": "Verubecestat",
    "HDAC1": "Vorinostat",
    "EGFR": "Gefitinib",
    "DRD2": "Haloperidol",
    "HIV-1 RT": "Nevirapine",
    "AMPC": "Cefoxitin",
    "MMP-13": "Marimastat",
}
63
+
64
+ # --- 3. Image Pre-Fetching Logic ---
65
+ # To ensure fast performance, we download all images when the app starts.
66
def pre_fetch_images():
    """Download a structure image for every mapped compound into CACHE_DIR.

    For each distinct compound name in ``PROTEIN_TO_COMPOUND_MAP`` the image
    is requested from the cactus.nci.nih.gov structure URL and stored as
    ``<alphanumeric name>.png`` inside ``CACHE_DIR``. Compounds whose file
    already exists are skipped. Failures are printed and never raised, so one
    bad download does not stop the others or the application start-up.
    """
    print("\n--- Starting Image Pre-Fetching ---")
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    for compound_name in set(PROTEIN_TO_COMPOUND_MAP.values()):
        sanitized_name = "".join(c for c in compound_name if c.isalnum())
        local_image_path = os.path.join(CACHE_DIR, f"{sanitized_name}.png")

        if os.path.exists(local_image_path):
            print(f"Image for '{compound_name}' already in cache. Skipping.")
            continue

        # The image is written to a temporary sibling file and only renamed
        # to its final name once the save succeeded. An interrupted or failed
        # save therefore never leaves a broken file that later runs would
        # treat as "already in cache".
        partial_path = f"{local_image_path}.part"
        try:
            print(f"Downloading image for '{compound_name}'...")
            url_safe_name = urllib.parse.quote(compound_name)
            image_url = f"https://cactus.nci.nih.gov/chemical/structure/{url_safe_name}/image"

            response = requests.get(image_url, timeout=20, headers=headers)

            if response.status_code != 200:
                # Report the real cause instead of blaming the content type.
                print(f" ❌ Failed for '{compound_name}': Server returned HTTP {response.status_code}.")
            elif 'image' not in response.headers.get('Content-Type', ''):
                print(f" ❌ Failed for '{compound_name}': Server returned non-image content.")
            else:
                with Image.open(BytesIO(response.content)) as image:
                    # The temp name has no .png suffix, so the format must
                    # be given explicitly.
                    image.save(partial_path, format="PNG")
                os.replace(partial_path, local_image_path)
                print(f" ✅ Success: Saved '{compound_name}' to cache.")
        except Exception as e:
            # Best effort by design: log and move on to the next compound.
            print(f" ❌ Failed for '{compound_name}': {e}")
        finally:
            # Remove a leftover temp file; after a successful rename there is
            # nothing to remove.
            try:
                os.remove(partial_path)
            except OSError:
                pass
    print("--- Image Pre-Fetching Complete ---\n")
94
+
95
+ # Run the pre-fetching function immediately
96
# Only warm the image cache when the models loaded; otherwise the UI shows
# just the error banner and no prediction can request an image.
if MODELS_LOADED:
    pre_fetch_images()
98
+
99
+ # --- 4. Helper and Prediction Functions ---
100
def create_error_image(message):
    """Return a white 400x300 RGB image with ``message`` drawn in red.

    The text is placed at (10, 10). The DejaVu Sans TrueType font at size 15
    is used when it can be opened; otherwise Pillow's built-in default font
    is used instead.
    """
    try:
        text_font = ImageFont.truetype("DejaVuSans.ttf", 15)
    except IOError:
        # Font file not available on this machine.
        text_font = ImageFont.load_default()

    canvas = Image.new('RGB', (400, 300), color=(255, 255, 255))
    ImageDraw.Draw(canvas).text((10, 10), message, fill=(200, 0, 0), font=text_font)
    return canvas
109
+
110
def get_compound_structure_image_from_cache(compound_name):
    """Return the cached structure image for ``compound_name``.

    The file looked up is ``<alphanumeric characters of the name>.png``
    inside ``CACHE_DIR``. When the file is missing, or exists but cannot be
    read as an image, a placeholder produced by ``create_error_image`` is
    returned instead, so the caller always receives an image and never an
    exception from a bad cache entry.
    """
    sanitized_name = "".join(c for c in compound_name if c.isalnum())
    local_image_path = os.path.join(CACHE_DIR, f"{sanitized_name}.png")

    try:
        with Image.open(local_image_path) as cached_image:
            # Image.open is lazy; copy() reads the pixel data while the file
            # is still open and returns an image that no longer depends on
            # the file handle, which the with-block then closes.
            return cached_image.copy()
    except FileNotFoundError:
        return create_error_image(f"Image for '{compound_name}'\nwas not found in local cache.")
    except OSError as e:
        # Covers corrupt or truncated files (Pillow's UnidentifiedImageError
        # is an OSError) as well as permission problems.
        print(f"Could not read cached image '{local_image_path}': {e}")
        return create_error_image(f"Image for '{compound_name}'\ncould not be read from local cache.")
118
+
119
def master_predict(model_choice, mol_weight, logp, hba, hbd, tpsa, rf_m, sc, l_enc, keras_m, rl_a):
    """Predict protein targets from five molecular properties.

    Args:
        model_choice: "Normal MLP" selects ``keras_m``; any other value
            selects ``rl_a``.
        mol_weight, logp, hba, hbd, tpsa: the five input properties, fed to
            the scaler in exactly this column order.
        rf_m: classifier with ``predict``, used for the single-label result.
        sc: fitted scaler with ``transform``.
        l_enc: label encoder with ``inverse_transform`` (index -> name).
        keras_m: model with ``predict(x, verbose=0)`` returning one row of
            per-class scores per input row.
        rl_a: callable PyTorch module returning per-class scores.

    Returns:
        A 3-tuple: the random-forest class name, a DataFrame with columns
        "Protein", "Probability", "Probability %" holding up to three
        highest-scoring classes (best first), and the cached structure image
        of the compound associated with the best class.
    """
    features = ['Molecular Weight', 'LogP', 'HBA', 'HBD', 'TPSA']
    property_values = [mol_weight, logp, hba, hbd, tpsa]
    input_df = pd.DataFrame([dict(zip(features, property_values))], columns=features)
    scaled_input = sc.transform(input_df)

    # Random forest: one label, decoded back to its class name.
    rf_index = rf_m.predict(scaled_input)[0]
    rf_label = l_enc.inverse_transform([rf_index])[0]

    # MLP: per-class scores from whichever network was selected.
    if model_choice != "Normal MLP":
        with torch.no_grad():
            torch_input = torch.tensor(scaled_input, dtype=torch.float32)
            class_scores = rl_a(torch_input).numpy()[0]
    else:
        class_scores = keras_m.predict(scaled_input, verbose=0)[0]

    # Indices of the three largest scores, largest first.
    best_indices = np.argsort(class_scores)[::-1][:3]
    protein_names = [l_enc.inverse_transform([idx])[0] for idx in best_indices]
    probabilities = [float(class_scores[idx]) for idx in best_indices]
    table_rows = [
        [name, prob, f"{prob:.2%}"]
        for name, prob in zip(protein_names, probabilities)
    ]
    mlp_results_df = pd.DataFrame(table_rows, columns=["Protein", "Probability", "Probability %"])

    # Show the compound mapped to the best class; when the class has no
    # mapping, its own name is used for the cache lookup.
    best_protein = mlp_results_df["Protein"].iloc[0]
    compound_name = PROTEIN_TO_COMPOUND_MAP.get(best_protein, best_protein)
    compound_image = get_compound_structure_image_from_cache(compound_name)

    return rf_label, mlp_results_df, compound_image
149
+
150
+ # --- 5. Build the Gradio Interface ---
151
# Build the page: inputs on the left, results on the right. When the models
# failed to load, only an error banner is rendered and no controls exist.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Drug Discovery: Protein Target Predictor")
    if MODELS_LOADED:
        # Bind the loaded models once so the click handler only receives the
        # six values coming from the UI.
        predict_with_models = partial(
            master_predict,
            rf_m=rf_model,
            sc=scaler,
            l_enc=le,
            keras_m=keras_model,
            rl_a=rl_agent,
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 1. Choose Your Model")
                model_choice = gr.Radio(
                    ["Normal MLP", "RL Upgraded MLP"],
                    label="Select an MLP Model",
                    value="Normal MLP",
                )
                gr.Markdown("### 2. Input Molecular Properties")
                mw_slider = gr.Slider(
                    100, 1000, value=350, step=1,
                    label="Molecular Weight (g/mol)",
                )
                logp_slider = gr.Slider(
                    -5, 10, value=2.5, step=0.1,
                    label="LogP (Lipophilicity)",
                )
                hba_slider = gr.Slider(
                    0, 20, value=4, step=1,
                    label="HBA (Hydrogen Bond Acceptors)",
                )
                hbd_slider = gr.Slider(
                    0, 20, value=2, step=1,
                    label="HBD (Hydrogen Bond Donors)",
                )
                tpsa_slider = gr.Slider(
                    0, 300, value=60, step=1,
                    label="TPSA (Topological Polar Surface Area Ų)",
                )
                submit_btn = gr.Button("Predict Target Protein", variant="primary")
            with gr.Column(scale=2):
                gr.Markdown("### 3. Prediction Results")
                out_rf = gr.Textbox(label="Random Forest Prediction (Most Likely Target)")
                out_mlp = gr.DataFrame(
                    headers=["Protein", "Probability", "Probability %"],
                    label="Top 3 MLP Predictions",
                    datatype=["str", "number", "str"],
                )
                out_image = gr.Image(label="2D Structure of an Associated Compound", type="pil")

        # Input order must match the leading positional parameters of
        # master_predict; output order must match its returned tuple.
        submit_btn.click(
            fn=predict_with_models,
            inputs=[model_choice, mw_slider, logp_slider, hba_slider, hbd_slider, tpsa_slider],
            outputs=[out_rf, out_mlp, out_image],
        )
    else:
        gr.Markdown("## ERROR: MODELS FAILED TO LOAD. PLEASE CHECK THE REPOSITORY AND LOGS.")

# Start the web server for the interface.
iface.launch()