root committed on
Commit
5e4e50b
·
1 Parent(s): 0084cd1

Add Gradio UI with single and batch prediction

Browse files
Files changed (2) hide show
  1. app.py +251 -6
  2. requirements.txt +2 -0
app.py CHANGED
@@ -5,23 +5,33 @@ from typing import List, Dict, Any
5
  ROOT = Path(__file__).parent
6
  sys.path.insert(0, str(ROOT))
7
 
 
8
  from fastapi import FastAPI, HTTPException
9
  from pydantic import BaseModel, Field
10
  import numpy as np
11
 
12
  from src import EnhancedFeatureExtractor, Tox21Ensemble
13
 
14
- app = FastAPI(
15
- title="Rasayan Tox21 Classifier",
16
- description="Self-Normalizing Neural Network ensemble for Tox21 toxicity prediction",
17
- version="1.0.0"
18
- )
19
-
20
  TASKS = [
21
  "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
22
  "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
23
  ]
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  FEATURE_KEYS = [
26
  "ecfps", "maccs", "rdkit_descrs", "tox", "rdkit_filters",
27
  "similarity", "max_similarity", "db_similarity"
@@ -38,6 +48,239 @@ ensemble = Tox21Ensemble(ROOT / "checkpoints" / "ensemble.pt")
38
  print("Model loaded successfully!")
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class PredictRequest(BaseModel):
42
  smiles: List[str] = Field(..., min_length=1, max_length=1000)
43
 
@@ -113,6 +356,8 @@ def health():
113
  return {"status": "ok"}
114
 
115
 
 
 
116
  if __name__ == "__main__":
117
  import uvicorn
118
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
5
  ROOT = Path(__file__).parent
6
  sys.path.insert(0, str(ROOT))
7
 
8
+ import gradio as gr
9
  from fastapi import FastAPI, HTTPException
10
  from pydantic import BaseModel, Field
11
  import numpy as np
12
 
13
  from src import EnhancedFeatureExtractor, Tox21Ensemble
14
 
 
 
 
 
 
 
15
  TASKS = [
16
  "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
17
  "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
18
  ]
19
 
20
+ TASK_DESCRIPTIONS = {
21
+ "NR-AR": "Androgen Receptor",
22
+ "NR-AR-LBD": "Androgen Receptor LBD",
23
+ "NR-AhR": "Aryl Hydrocarbon Receptor",
24
+ "NR-Aromatase": "Aromatase (CYP19A1)",
25
+ "NR-ER": "Estrogen Receptor",
26
+ "NR-ER-LBD": "Estrogen Receptor LBD",
27
+ "NR-PPAR-gamma": "PPARγ",
28
+ "SR-ARE": "Antioxidant Response",
29
+ "SR-ATAD5": "DNA Damage (ATAD5)",
30
+ "SR-HSE": "Heat Shock Response",
31
+ "SR-MMP": "Mitochondrial Toxicity",
32
+ "SR-p53": "Genotoxicity (p53)"
33
+ }
34
+
35
  FEATURE_KEYS = [
36
  "ecfps", "maccs", "rdkit_descrs", "tox", "rdkit_filters",
37
  "similarity", "max_similarity", "db_similarity"
 
48
  print("Model loaded successfully!")
49
 
50
 
51
def predict_toxicity(smiles_input: str) -> tuple:
    """Run batch toxicity prediction for newline-separated SMILES strings.

    Args:
        smiles_input: Raw text with one SMILES per line. Surrounding
            whitespace and blank lines are ignored. ``None`` is treated
            like an empty string.

    Returns:
        A ``(table, status)`` pair. ``table`` is a pandas DataFrame with a
        ``SMILES`` column (truncated to 50 characters plus ``"..."`` for
        display) and one column per entry in ``TASKS`` holding either a
        percentage string or ``"Invalid"``. ``table`` is ``None`` when the
        input is empty, exceeds 100 molecules, or processing raised.
        ``status`` is a human-readable message for the UI.
    """
    # Guard against a missing value as well as whitespace-only input;
    # calling .strip() on None would raise before the try block below.
    if not smiles_input or not smiles_input.strip():
        return None, "Please enter at least one SMILES"

    lines = [s.strip() for s in smiles_input.strip().split('\n') if s.strip()]

    if len(lines) > 100:
        return None, "Maximum 100 molecules per request"

    try:
        features_dict, valid = extractor.extract_features(lines)

        # Concatenate the available feature groups column-wise in the
        # fixed FEATURE_KEYS order; groups missing from the dict are skipped.
        features = np.concatenate(
            [features_dict[k] for k in FEATURE_KEYS if k in features_dict],
            axis=1
        )
        # Replace NaN/inf with 0.0 before handing the matrix to the model.
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        probs = ensemble.predict(features)

        # NOTE(review): indexing probs/valid by input position assumes the
        # extractor returns one feature row per input line, including the
        # invalid ones — confirm against EnhancedFeatureExtractor.
        results = []
        for i, smi in enumerate(lines):
            row = {"SMILES": smi[:50] + "..." if len(smi) > 50 else smi}
            if valid[i]:
                for j, task in enumerate(TASKS):
                    row[task] = f"{float(probs[i, j]):.1%}"
            else:
                for task in TASKS:
                    row[task] = "Invalid"
            results.append(row)

        # Imported lazily inside the try so a missing pandas install is
        # reported through the status message like any other failure.
        import pandas as pd
        df = pd.DataFrame(results)

        return df, f"Processed {len(lines)} molecule(s)"

    except Exception as e:
        # UI boundary: surface the failure as a status string, not a crash.
        return None, f"Error: {str(e)}"
92
+
93
+
94
def predict_single(smiles: str) -> str:
    """Predict toxicity for one SMILES string and format a text report.

    Args:
        smiles: A single SMILES string. Surrounding whitespace is removed
            before feature extraction. ``None`` is treated like an empty
            string.

    Returns:
        A multi-line report listing every entry in ``TASKS`` sorted by
        descending score, each with a 20-character bar, the percentage, a
        risk tag (HIGH >= 0.7, MED >= 0.4, LOW >= 0.2, otherwise MIN) and
        its description from ``TASK_DESCRIPTIONS``. Returns a short message
        instead when the input is empty, flagged invalid by the extractor,
        or processing raised.
    """
    # Normalise once so the emptiness check and the extractor see the same
    # string (the batch path strips each line as well).
    smiles = (smiles or "").strip()
    if not smiles:
        return "Enter a SMILES string"

    try:
        features_dict, valid = extractor.extract_features([smiles])

        if not valid[0]:
            return "Invalid SMILES structure"

        # Concatenate the available feature groups column-wise in the
        # fixed FEATURE_KEYS order; groups missing from the dict are skipped.
        features = np.concatenate(
            [features_dict[k] for k in FEATURE_KEYS if k in features_dict],
            axis=1
        )
        # Replace NaN/inf with 0.0 before handing the matrix to the model.
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        probs = ensemble.predict(features)

        lines = []
        lines.append("═" * 45)
        lines.append(" TOXICITY PREDICTION RESULTS")
        lines.append("═" * 45)

        # Highest score first so the most concerning endpoints lead.
        sorted_results = sorted(
            [(task, float(probs[0, j])) for j, task in enumerate(TASKS)],
            key=lambda x: -x[1]
        )

        for task, score in sorted_results:
            desc = TASK_DESCRIPTIONS[task]
            # Clamp to the bar width so a score outside [0, 1] cannot
            # produce a bar that is not exactly 20 characters wide.
            bar_len = max(0, min(20, int(score * 20)))
            bar = "█" * bar_len + "░" * (20 - bar_len)

            if score >= 0.7:
                risk = "HIGH"
            elif score >= 0.4:
                risk = "MED "
            elif score >= 0.2:
                risk = "LOW "
            else:
                risk = "MIN "

            lines.append(f"{task:15} {bar} {score:5.1%} [{risk}]")
            lines.append(f" └─ {desc}")

        lines.append("═" * 45)

        return "\n".join(lines)

    except Exception as e:
        # UI boundary: surface the failure as text, not a crash.
        return f"Error: {str(e)}"
145
+
146
+
147
+ EXAMPLES = [
148
+ ["CCO"],
149
+ ["CC(=O)Nc1ccc(O)cc1"],
150
+ ["c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45"],
151
+ ["CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C"],
152
+ ["CC12CCC3c4ccc(O)cc4CCC3C1CCC2O"],
153
+ ]
154
+
155
+ with gr.Blocks(
156
+ title="Rasayan Tox21 Classifier",
157
+ theme=gr.themes.Soft()
158
+ ) as demo:
159
+ gr.Markdown("""
160
+ # ☠️ Rasayan Tox21 Classifier
161
+
162
+ Predict molecular toxicity across **12 Tox21 endpoints** using a Self-Normalizing Neural Network ensemble.
163
+
164
+ | Model | Features | Training |
165
+ |-------|----------|----------|
166
+ | 10-fold SNN Ensemble | 11,377 molecular descriptors | 40-fold CV, AUC: 0.882 |
167
+ """)
168
+
169
+ with gr.Tabs():
170
+ with gr.TabItem("Single Molecule"):
171
+ with gr.Row():
172
+ with gr.Column(scale=1):
173
+ single_input = gr.Textbox(
174
+ label="SMILES",
175
+ placeholder="Enter SMILES (e.g., CCO for ethanol)",
176
+ lines=1
177
+ )
178
+ single_btn = gr.Button("Predict", variant="primary")
179
+ gr.Examples(
180
+ examples=EXAMPLES,
181
+ inputs=single_input,
182
+ label="Example Molecules"
183
+ )
184
+
185
+ with gr.Column(scale=2):
186
+ single_output = gr.Textbox(
187
+ label="Toxicity Profile",
188
+ lines=30,
189
+ show_copy_button=True
190
+ )
191
+
192
+ single_btn.click(
193
+ fn=predict_single,
194
+ inputs=single_input,
195
+ outputs=single_output
196
+ )
197
+
198
+ with gr.TabItem("Batch Processing"):
199
+ gr.Markdown("Enter multiple SMILES (one per line, max 100)")
200
+
201
+ batch_input = gr.Textbox(
202
+ label="SMILES List",
203
+ placeholder="CCO\nCC(=O)Nc1ccc(O)cc1\nc1ccccc1",
204
+ lines=5
205
+ )
206
+ batch_btn = gr.Button("Process Batch", variant="primary")
207
+ batch_status = gr.Textbox(label="Status", lines=1)
208
+ batch_output = gr.Dataframe(
209
+ label="Results",
210
+ wrap=True
211
+ )
212
+
213
+ batch_btn.click(
214
+ fn=predict_toxicity,
215
+ inputs=batch_input,
216
+ outputs=[batch_output, batch_status]
217
+ )
218
+
219
+ with gr.TabItem("About"):
220
+ gr.Markdown("""
221
+ ## Model Architecture
222
+
223
+ **Self-Normalizing Neural Networks (SNNs)** with SELU activation and AlphaDropout.
224
+
225
+ | Component | Details |
226
+ |-----------|---------|
227
+ | Hidden Layers | 8 × 768 units |
228
+ | Activation | SELU |
229
+ | Dropout | AlphaDropout (0.1) |
230
+ | Ensemble | Top-10 from 40-fold CV |
231
+ | Parameters | ~19M total |
232
+
233
+ ## Molecular Features (11,377 total)
234
+
235
+ | Feature | Dimensions | Description |
236
+ |---------|------------|-------------|
237
+ | ECFP6 | 8,192 | Morgan fingerprints (radius 3) |
238
+ | MACCS | 167 | Structural keys |
239
+ | RDKit | 208 | Physicochemical descriptors |
240
+ | Toxicophores | 1,868 | Toxicity structural alerts |
241
+ | Filters | 815 | PAINS, BRENK, NIH, ZINC |
242
+ | Similarity | 127 | Target ligand similarity |
243
+
244
+ ## Tox21 Endpoints
245
+
246
+ ### Nuclear Receptor Panel
247
+ - **NR-AR**: Androgen Receptor
248
+ - **NR-AR-LBD**: AR Ligand Binding Domain
249
+ - **NR-AhR**: Aryl Hydrocarbon Receptor
250
+ - **NR-Aromatase**: CYP19A1 Enzyme
251
+ - **NR-ER**: Estrogen Receptor
252
+ - **NR-ER-LBD**: ER Ligand Binding Domain
253
+ - **NR-PPAR-gamma**: Peroxisome Proliferator-Activated Receptor
254
+
255
+ ### Stress Response Panel
256
+ - **SR-ARE**: Antioxidant Response Element
257
+ - **SR-ATAD5**: DNA Damage Response
258
+ - **SR-HSE**: Heat Shock Element
259
+ - **SR-MMP**: Mitochondrial Membrane Potential
260
+ - **SR-p53**: Tumor Suppressor p53
261
+
262
+ ## Risk Interpretation
263
+
264
+ | Score | Risk Level |
265
+ |-------|------------|
266
+ | < 20% | Minimal |
267
+ | 20-40% | Low |
268
+ | 40-70% | Moderate |
269
+ | ≥ 70% | High |
270
+
271
+ ---
272
+
273
+ Built by [Rasayan Labs](https://rasayan.ai)
274
+ """)
275
+
276
+ gr.Markdown("""
277
+ ---
278
+ **API Endpoints**: `/predict` (POST), `/metadata` (GET), `/health` (GET)
279
+ """)
280
+
281
+ app = FastAPI()
282
+
283
+
284
  class PredictRequest(BaseModel):
285
  smiles: List[str] = Field(..., min_length=1, max_length=1000)
286
 
 
356
  return {"status": "ok"}
357
 
358
 
359
+ app = gr.mount_gradio_app(app, demo, path="/")
360
+
361
  if __name__ == "__main__":
362
  import uvicorn
363
  uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -5,3 +5,5 @@ numpy>=1.24.0
5
  torch>=2.0.0
6
  rdkit>=2023.3.1
7
  scikit-learn>=1.3.0
 
 
 
5
  torch>=2.0.0
6
  rdkit>=2023.3.1
7
  scikit-learn>=1.3.0
8
+ gradio>=4.0.0
9
+ pandas>=2.0.0