rasayan-tox21 / app.py
root
Fix Gradio compatibility
1b20033
import sys
from pathlib import Path
from typing import List, Dict, Any
ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import numpy as np
from src import EnhancedFeatureExtractor, Tox21Ensemble
# Human-readable names of the 12 Tox21 assays. Insertion order matters: the
# j-th entry names column j of the ensemble's probability matrix, and every
# table / JSON response below iterates tasks in this order.
TASK_DESCRIPTIONS = {
    "NR-AR": "Androgen Receptor",
    "NR-AR-LBD": "Androgen Receptor LBD",
    "NR-AhR": "Aryl Hydrocarbon Receptor",
    "NR-Aromatase": "Aromatase (CYP19A1)",
    "NR-ER": "Estrogen Receptor",
    "NR-ER-LBD": "Estrogen Receptor LBD",
    "NR-PPAR-gamma": "PPARγ",
    "SR-ARE": "Antioxidant Response",
    "SR-ATAD5": "DNA Damage (ATAD5)",
    "SR-HSE": "Heat Shock Response",
    "SR-MMP": "Mitochondrial Toxicity",
    "SR-p53": "Genotoxicity (p53)",
}

# Task names in model-output column order, derived from the mapping above so
# the two can never disagree.
TASKS = list(TASK_DESCRIPTIONS)

# Feature groups concatenated (in this order, along axis 1) to build the model
# input; groups absent from the extractor's output are skipped.
FEATURE_KEYS = [
    "ecfps",
    "maccs",
    "rdkit_descrs",
    "tox",
    "rdkit_filters",
    "similarity",
    "max_similarity",
    "db_similarity",
]

# Batch size advertised by GET /metadata. NOTE(review): POST /predict itself
# accepts up to 1000 SMILES per request.
MAX_BATCH_SIZE = 256
# Build the feature extractor and the trained ensemble once, at import time;
# both module-level objects are shared by the Gradio callbacks and the
# FastAPI routes defined later in this file.
print("Loading model...")
extractor = EnhancedFeatureExtractor(
    toxicophores_path=ROOT.joinpath("data", "toxicophores_validated.json"),
    db_ligands_path=ROOT.joinpath("data", "target_ligands_validated.json"),
)
ensemble = Tox21Ensemble(ROOT.joinpath("checkpoints", "ensemble.pt"))
print("Model loaded successfully!")
def predict_toxicity(smiles_input: str) -> tuple:
    """Predict every Tox21 endpoint for a newline-separated batch of SMILES.

    Backs the "Batch Processing" tab.

    Args:
        smiles_input: Raw textbox contents, one SMILES per line. Blank lines
            and surrounding whitespace are ignored; ``None`` is treated as
            empty input.

    Returns:
        ``(dataframe, status)``. ``dataframe`` is a ``pandas.DataFrame`` with a
        ``SMILES`` column (truncated to 50 characters plus ``...``) and one
        column per task holding a percentage string, or ``"Invalid"`` for
        molecules the feature extractor rejected. For empty input, more than
        100 molecules, or any exception, ``dataframe`` is ``None`` and
        ``status`` explains why.
    """
    # Gradio may hand over None for an untouched textbox; treat it as empty
    # instead of crashing on .strip().
    if not smiles_input or not smiles_input.strip():
        return None, "Please enter at least one SMILES"
    lines = [s.strip() for s in smiles_input.strip().split("\n") if s.strip()]
    if len(lines) > 100:
        return None, "Maximum 100 molecules per request"
    try:
        import pandas as pd

        features_dict, valid = extractor.extract_features(lines)
        features = np.concatenate(
            [features_dict[k] for k in FEATURE_KEYS if k in features_dict],
            axis=1,
        )
        # Replace any non-finite descriptor values so the model gets finite input.
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        probs = ensemble.predict(features)

        results = []
        n_invalid = 0
        for i, smi in enumerate(lines):
            row = {"SMILES": smi[:50] + "..." if len(smi) > 50 else smi}
            if valid[i]:
                # Column j of probs corresponds to TASKS[j].
                row.update(
                    {task: f"{float(probs[i, j]):.1%}" for j, task in enumerate(TASKS)}
                )
            else:
                n_invalid += 1
                row.update(dict.fromkeys(TASKS, "Invalid"))
            results.append(row)

        status = f"Processed {len(lines)} molecule(s)"
        if n_invalid:
            status += f" ({n_invalid} invalid)"
        return pd.DataFrame(results), status
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of raising.
        return None, f"Error: {e}"
def predict_single(smiles: str) -> str:
if not smiles.strip():
return "Enter a SMILES string"
try:
features_dict, valid = extractor.extract_features([smiles])
if not valid[0]:
return "Invalid SMILES structure"
features = np.concatenate(
[features_dict[k] for k in FEATURE_KEYS if k in features_dict],
axis=1
)
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
probs = ensemble.predict(features)
lines = []
lines.append("═" * 45)
lines.append(" TOXICITY PREDICTION RESULTS")
lines.append("═" * 45)
sorted_results = sorted(
[(task, float(probs[0, j])) for j, task in enumerate(TASKS)],
key=lambda x: -x[1]
)
for task, score in sorted_results:
desc = TASK_DESCRIPTIONS[task]
bar_len = int(score * 20)
bar = "█" * bar_len + "░" * (20 - bar_len)
if score >= 0.7:
risk = "HIGH"
elif score >= 0.4:
risk = "MED "
elif score >= 0.2:
risk = "LOW "
else:
risk = "MIN "
lines.append(f"{task:15} {bar} {score:5.1%} [{risk}]")
lines.append(f" └─ {desc}")
lines.append("═" * 45)
return "\n".join(lines)
except Exception as e:
return f"Error: {str(e)}"
# Clickable example molecules for the single-molecule tab. gr.Examples expects
# one inner list per example, holding one value per input component.
EXAMPLES = [
    [smi]
    for smi in (
        "CCO",  # ethanol
        "CC(=O)Nc1ccc(O)cc1",  # acetaminophen (paracetamol)
        "c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45",  # polycyclic aromatic, presumably benzo[a]pyrene
        "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C",  # testosterone skeleton (no stereo)
        "CC12CCC3c4ccc(O)cc4CCC3C1CCC2O",  # estradiol skeleton (no stereo)
    )
]
# Gradio front end. Inside gr.Blocks, components are laid out in the order they
# are created within the nested context managers, so statement order here is
# the page layout.
with gr.Blocks(
    title="Rasayan Tox21 Classifier",
    theme=gr.themes.Soft()
) as demo:
    # Page header with a one-row model summary table.
    gr.Markdown("""
# ☠️ Rasayan Tox21 Classifier
Predict molecular toxicity across **12 Tox21 endpoints** using a Self-Normalizing Neural Network ensemble.
| Model | Features | Training |
|-------|----------|----------|
| 10-fold SNN Ensemble | 11,377 molecular descriptors | 40-fold CV, AUC: 0.882 |
""")
    with gr.Tabs():
        # Tab 1: one SMILES in, ranked per-endpoint text profile out.
        with gr.TabItem("Single Molecule"):
            with gr.Row():
                # Left column: input box, trigger button and example molecules.
                with gr.Column(scale=1):
                    single_input = gr.Textbox(
                        label="SMILES",
                        placeholder="Enter SMILES (e.g., CCO for ethanol)",
                        lines=1
                    )
                    single_btn = gr.Button("Predict", variant="primary")
                    # No fn is passed, so clicking an example only fills
                    # single_input; the user still presses Predict.
                    gr.Examples(
                        examples=EXAMPLES,
                        inputs=single_input,
                        label="Example Molecules"
                    )
                # Right column (scale 2 vs 1): predict_single's text report.
                with gr.Column(scale=2):
                    single_output = gr.Textbox(
                        label="Toxicity Profile",
                        lines=30
                    )
            # Textbox text in, formatted profile string out.
            single_btn.click(
                fn=predict_single,
                inputs=single_input,
                outputs=single_output
            )
        # Tab 2: newline-separated SMILES in, results table + status out.
        # The "max 100" here matches the limit enforced in predict_toxicity.
        with gr.TabItem("Batch Processing"):
            gr.Markdown("Enter multiple SMILES (one per line, max 100)")
            batch_input = gr.Textbox(
                label="SMILES List",
                placeholder="CCO\nCC(=O)Nc1ccc(O)cc1\nc1ccccc1",
                lines=5
            )
            batch_btn = gr.Button("Process Batch", variant="primary")
            batch_status = gr.Textbox(label="Status", lines=1)
            batch_output = gr.Dataframe(
                label="Results",
                wrap=True
            )
            # predict_toxicity returns (DataFrame or None, status message);
            # the outputs list follows that tuple order.
            batch_btn.click(
                fn=predict_toxicity,
                inputs=batch_input,
                outputs=[batch_output, batch_status]
            )
        # Tab 3: static documentation. NOTE(review): these figures are
        # hand-written, not read from the loaded model; keep them in sync
        # manually (e.g. risk thresholds must match predict_single).
        with gr.TabItem("About"):
            gr.Markdown("""
## Model Architecture
**Self-Normalizing Neural Networks (SNNs)** with SELU activation and AlphaDropout.
| Component | Details |
|-----------|---------|
| Hidden Layers | 8 × 768 units |
| Activation | SELU |
| Dropout | AlphaDropout (0.1) |
| Ensemble | Top-10 from 40-fold CV |
| Parameters | ~19M total |
## Molecular Features (11,377 total)
| Feature | Dimensions | Description |
|---------|------------|-------------|
| ECFP6 | 8,192 | Morgan fingerprints (radius 3) |
| MACCS | 167 | Structural keys |
| RDKit | 208 | Physicochemical descriptors |
| Toxicophores | 1,868 | Toxicity structural alerts |
| Filters | 815 | PAINS, BRENK, NIH, ZINC |
| Similarity | 127 | Target ligand similarity |
## Tox21 Endpoints
### Nuclear Receptor Panel
- **NR-AR**: Androgen Receptor
- **NR-AR-LBD**: AR Ligand Binding Domain
- **NR-AhR**: Aryl Hydrocarbon Receptor
- **NR-Aromatase**: CYP19A1 Enzyme
- **NR-ER**: Estrogen Receptor
- **NR-ER-LBD**: ER Ligand Binding Domain
- **NR-PPAR-gamma**: Peroxisome Proliferator-Activated Receptor
### Stress Response Panel
- **SR-ARE**: Antioxidant Response Element
- **SR-ATAD5**: DNA Damage Response
- **SR-HSE**: Heat Shock Element
- **SR-MMP**: Mitochondrial Membrane Potential
- **SR-p53**: Tumor Suppressor p53
## Risk Interpretation
| Score | Risk Level |
|-------|------------|
| < 20% | Minimal |
| 20-40% | Low |
| 40-70% | Moderate |
| ≥ 70% | High |
---
Built by [Rasayan Labs](https://rasayan.ai)
""")
    # Footer below the tabs, listing the FastAPI routes defined in this file.
    gr.Markdown("""
---
**API Endpoints**: `/predict` (POST), `/metadata` (GET), `/health` (GET)
""")
# ASGI application hosting the JSON API routes below; the Gradio UI is mounted
# onto it afterwards via gr.mount_gradio_app.
app = FastAPI()


class PredictRequest(BaseModel):
    # Body of POST /predict: a list of 1-1000 SMILES strings.
    # NOTE(review): list-length enforcement through min_length/max_length relies
    # on Pydantic v2 semantics; predict() re-checks both bounds manually.
    smiles: List[str] = Field(..., min_length=1, max_length=1000)


class PredictResponse(BaseModel):
    # Maps each input SMILES to {task name: probability}.
    predictions: Dict[str, Dict[str, float]]
    # Identifies the served model (predict() fills in "name" and "version").
    model_info: Dict[str, Any]


class MetadataResponse(BaseModel):
    # Schema of GET /metadata; field values are supplied by get_metadata().
    model_name: str
    version: str
    max_batch_size: int
    tox_endpoints: List[str]
    description: str
@app.get("/metadata", response_model=MetadataResponse)
def get_metadata():
    # GET /metadata: static description of the served model.
    # max_batch_size reports MAX_BATCH_SIZE; tox_endpoints lists TASKS in
    # model-output column order.
    description = (
        "10-fold ensemble of Self-Normalizing Neural Networks trained on "
        "Tox21 Challenge data. Features: ECFP6, MACCS, RDKit descriptors, "
        "toxicophores, and target similarity."
    )
    return dict(
        model_name="Rasayan Tox21 SNN Ensemble",
        version="1.0.0",
        max_batch_size=MAX_BATCH_SIZE,
        tox_endpoints=TASKS,
        description=description,
    )
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    # POST /predict: score every submitted SMILES on all Tox21 endpoints.
    # Each SMILES maps to {task: probability}. Molecules the extractor flags
    # invalid get 0.5 for every task instead of being dropped. Duplicate
    # SMILES collapse into a single key.
    batch = request.smiles
    n_smiles = len(batch)
    # Manual bounds checks backing up the Field constraints on PredictRequest.
    if n_smiles > 1000:
        raise HTTPException(status_code=400, detail="Maximum 1000 SMILES per request")
    if n_smiles == 0:
        raise HTTPException(status_code=400, detail="At least 1 SMILES required")
    try:
        feats, is_valid = extractor.extract_features(batch)
        model_input = np.concatenate(
            [feats[key] for key in FEATURE_KEYS if key in feats],
            axis=1,
        )
        model_input = np.nan_to_num(model_input, nan=0.0, posinf=0.0, neginf=0.0)
        probs = ensemble.predict(model_input)
        predictions = {
            smi: (
                {task: float(probs[row, col]) for col, task in enumerate(TASKS)}
                if is_valid[row]
                else dict.fromkeys(TASKS, 0.5)
            )
            for row, smi in enumerate(batch)
        }
        return {
            "predictions": predictions,
            "model_info": {
                "name": "Rasayan Tox21 SNN Ensemble",
                "version": "1.0.0",
            },
        }
    except Exception as e:
        # Any featurization/inference failure becomes a 500 carrying the message.
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health():
    # GET /health: constant payload; performs no model or dependency checks.
    payload = {"status": "ok"}
    return payload
# Serve the Gradio UI at "/" on the same ASGI app that already carries the
# /predict, /metadata and /health routes registered above.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Local entry point: serve on all interfaces, port 7860.
    from uvicorn import run

    run(app, host="0.0.0.0", port=7860)