Spaces:
Running
Running
| import sys | |
| from pathlib import Path | |
| from typing import List, Dict, Any | |
# Directory containing this file. It is prepended to sys.path so that the
# bundled ``src`` package imported below is found first.
ROOT = Path(__file__).parent
sys.path[:0] = [str(ROOT)]
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| import numpy as np | |
| from src import EnhancedFeatureExtractor, Tox21Ensemble | |
# Human-readable label for each of the twelve Tox21 endpoints. Insertion order
# matters: TASKS below is derived from it, and the prediction code reads
# probs[:, j] as the score for TASKS[j].
TASK_DESCRIPTIONS = {
    "NR-AR": "Androgen Receptor",
    "NR-AR-LBD": "Androgen Receptor LBD",
    "NR-AhR": "Aryl Hydrocarbon Receptor",
    "NR-Aromatase": "Aromatase (CYP19A1)",
    "NR-ER": "Estrogen Receptor",
    "NR-ER-LBD": "Estrogen Receptor LBD",
    "NR-PPAR-gamma": "PPARγ",
    "SR-ARE": "Antioxidant Response",
    "SR-ATAD5": "DNA Damage (ATAD5)",
    "SR-HSE": "Heat Shock Response",
    "SR-MMP": "Mitochondrial Toxicity",
    "SR-p53": "Genotoxicity (p53)",
}

# Endpoint identifiers in model-output column order.
TASKS = list(TASK_DESCRIPTIONS)

# Feature groups concatenated, in this order and skipping any the extractor did
# not return, to build the model input matrix.
FEATURE_KEYS = [
    "ecfps",
    "maccs",
    "rdkit_descrs",
    "tox",
    "rdkit_filters",
    "similarity",
    "max_similarity",
    "db_similarity",
]

# Batch size reported as ``max_batch_size`` by the metadata endpoint.
MAX_BATCH_SIZE = 256
# Build the featurizer and load the trained ensemble once, at import time; the
# prediction functions below share these module-level instances.
print("Loading model...")
extractor = EnhancedFeatureExtractor(
    toxicophores_path=ROOT.joinpath("data", "toxicophores_validated.json"),
    db_ligands_path=ROOT.joinpath("data", "target_ligands_validated.json"),
)
ensemble = Tox21Ensemble(ROOT.joinpath("checkpoints", "ensemble.pt"))
print("Model loaded successfully!")
def predict_toxicity(smiles_input: str) -> tuple:
    """Score a batch of SMILES, one per line, for the "Batch Processing" tab.

    Blank lines are skipped and each remaining line is stripped. Any standard
    line ending (LF, CRLF or a bare CR) separates molecules.

    Args:
        smiles_input: Raw textbox contents. ``None`` is treated like "".

    Returns:
        ``(dataframe, status)``. On success ``dataframe`` is a
        ``pandas.DataFrame`` with one row per input line: a ``SMILES`` column
        (truncated to 50 characters plus "..." for display) followed by one
        column per task in ``TASKS`` order, holding the score as a percentage
        string, or ``"Invalid"`` for every task when the extractor rejects the
        SMILES. On empty input, more than 100 lines, or any error during
        prediction, ``dataframe`` is ``None`` and ``status`` says why.
    """
    if not smiles_input or not smiles_input.strip():
        return None, "Please enter at least one SMILES"
    # splitlines() also splits on CR-only line endings, which split("\n")
    # would merge into a single bogus SMILES.
    lines = [s.strip() for s in smiles_input.splitlines() if s.strip()]
    if len(lines) > 100:
        return None, "Maximum 100 molecules per request"
    try:
        import pandas as pd

        features_dict, valid = extractor.extract_features(lines)
        features = np.concatenate(
            [features_dict[k] for k in FEATURE_KEYS if k in features_dict],
            axis=1,
        )
        # Zero out NaN/±inf so the ensemble never sees non-finite inputs.
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        probs = ensemble.predict(features)

        results = []
        for i, smi in enumerate(lines):
            # Long SMILES are truncated for display only.
            row = {"SMILES": smi[:50] + "..." if len(smi) > 50 else smi}
            if valid[i]:
                row.update(
                    (task, f"{float(probs[i, j]):.1%}")
                    for j, task in enumerate(TASKS)
                )
            else:
                row.update(dict.fromkeys(TASKS, "Invalid"))
            results.append(row)
        return pd.DataFrame(results), f"Processed {len(lines)} molecule(s)"
    except Exception as e:
        return None, f"Error: {str(e)}"
def _risk_label(score: float) -> str:
    """Return the 4-character risk tag: >=0.7 HIGH, >=0.4 MED, >=0.2 LOW, else MIN."""
    if score >= 0.7:
        return "HIGH"
    if score >= 0.4:
        return "MED "
    if score >= 0.2:
        return "LOW "
    return "MIN "


def _score_bar(score: float, width: int = 20) -> str:
    """Render ``score`` as a ``width``-cell text bar, filled cells clamped to [0, width]."""
    filled = max(0, min(width, int(score * width)))
    return "█" * filled + "░" * (width - filled)


def predict_single(smiles: str) -> str:
    """Build a plain-text toxicity report for one SMILES (the "Single Molecule" tab).

    Surrounding whitespace is stripped before featurization, the same as the
    per-line stripping done for batch input.

    Args:
        smiles: SMILES string from the textbox. ``None`` or whitespace-only
            input is treated as empty.

    Returns:
        A multi-line report listing every task sorted by score (highest first).
        Each task gets a 20-cell bar, the percentage, a risk tag and the task
        description. For empty input, an invalid structure, or an error during
        prediction, a short message is returned instead.
    """
    if not smiles or not smiles.strip():
        return "Enter a SMILES string"
    smiles = smiles.strip()
    try:
        features_dict, valid = extractor.extract_features([smiles])
        if not valid[0]:
            return "Invalid SMILES structure"
        features = np.concatenate(
            [features_dict[k] for k in FEATURE_KEYS if k in features_dict],
            axis=1,
        )
        # Zero out NaN/±inf so the ensemble never sees non-finite inputs.
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        probs = ensemble.predict(features)

        # reverse=True keeps sort stability, so tied tasks stay in TASKS order.
        ranked = sorted(
            ((task, float(probs[0, j])) for j, task in enumerate(TASKS)),
            key=lambda item: item[1],
            reverse=True,
        )
        rule = "═" * 45
        report = [rule, " TOXICITY PREDICTION RESULTS", rule]
        for task, score in ranked:
            bar = _score_bar(score)
            report.append(f"{task:15} {bar} {score:5.1%} [{_risk_label(score)}]")
            report.append(f" └─ {TASK_DESCRIPTIONS[task]}")
        report.append(rule)
        return "\n".join(report)
    except Exception as e:
        return f"Error: {str(e)}"
# Sample inputs for the single-molecule tab. gr.Examples takes one inner list
# per example holding one value per input component (here only the SMILES).
EXAMPLES = [
    [smiles]
    for smiles in (
        "CCO",
        "CC(=O)Nc1ccc(O)cc1",
        "c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45",
        "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C",
        "CC12CCC3c4ccc(O)cc4CCC3C1CCC2O",
    )
]
# Gradio UI: a header, three tabs (single molecule, batch, about) and an API
# footer. Components render in the order they are created inside each context
# manager, so statement order here defines the page layout.
with gr.Blocks(
    title="Rasayan Tox21 Classifier",
    theme=gr.themes.Soft()
) as demo:
    gr.Markdown("""
# ☠️ Rasayan Tox21 Classifier
Predict molecular toxicity across **12 Tox21 endpoints** using a Self-Normalizing Neural Network ensemble.
| Model | Features | Training |
|-------|----------|----------|
| 10-fold SNN Ensemble | 11,377 molecular descriptors | 40-fold CV, AUC: 0.882 |
""")
    with gr.Tabs():
        # One SMILES in, ranked text report out (predict_single).
        with gr.TabItem("Single Molecule"):
            with gr.Row():
                with gr.Column(scale=1):
                    single_input = gr.Textbox(
                        label="SMILES",
                        placeholder="Enter SMILES (e.g., CCO for ethanol)",
                        lines=1
                    )
                    single_btn = gr.Button("Predict", variant="primary")
                    gr.Examples(
                        examples=EXAMPLES,
                        inputs=single_input,
                        label="Example Molecules"
                    )
                with gr.Column(scale=2):
                    single_output = gr.Textbox(
                        label="Toxicity Profile",
                        lines=30
                    )
            single_btn.click(
                fn=predict_single,
                inputs=single_input,
                outputs=single_output
            )
        # Newline-separated SMILES in, results table plus status text out
        # (predict_toxicity returns them in that order).
        with gr.TabItem("Batch Processing"):
            gr.Markdown("Enter multiple SMILES (one per line, max 100)")
            batch_input = gr.Textbox(
                label="SMILES List",
                placeholder="CCO\nCC(=O)Nc1ccc(O)cc1\nc1ccccc1",
                lines=5
            )
            batch_btn = gr.Button("Process Batch", variant="primary")
            batch_status = gr.Textbox(label="Status", lines=1)
            batch_output = gr.Dataframe(
                label="Results",
                wrap=True
            )
            batch_btn.click(
                fn=predict_toxicity,
                inputs=batch_input,
                outputs=[batch_output, batch_status]
            )
        # Static documentation. The risk table matches the thresholds used by
        # predict_single (0.2 / 0.4 / 0.7).
        with gr.TabItem("About"):
            gr.Markdown("""
## Model Architecture
**Self-Normalizing Neural Networks (SNNs)** with SELU activation and AlphaDropout.
| Component | Details |
|-----------|---------|
| Hidden Layers | 8 × 768 units |
| Activation | SELU |
| Dropout | AlphaDropout (0.1) |
| Ensemble | Top-10 from 40-fold CV |
| Parameters | ~19M total |
## Molecular Features (11,377 total)
| Feature | Dimensions | Description |
|---------|------------|-------------|
| ECFP6 | 8,192 | Morgan fingerprints (radius 3) |
| MACCS | 167 | Structural keys |
| RDKit | 208 | Physicochemical descriptors |
| Toxicophores | 1,868 | Toxicity structural alerts |
| Filters | 815 | PAINS, BRENK, NIH, ZINC |
| Similarity | 127 | Target ligand similarity |
## Tox21 Endpoints
### Nuclear Receptor Panel
- **NR-AR**: Androgen Receptor
- **NR-AR-LBD**: AR Ligand Binding Domain
- **NR-AhR**: Aryl Hydrocarbon Receptor
- **NR-Aromatase**: CYP19A1 Enzyme
- **NR-ER**: Estrogen Receptor
- **NR-ER-LBD**: ER Ligand Binding Domain
- **NR-PPAR-gamma**: Peroxisome Proliferator-Activated Receptor Gamma (PPARγ)
### Stress Response Panel
- **SR-ARE**: Antioxidant Response Element
- **SR-ATAD5**: DNA Damage Response
- **SR-HSE**: Heat Shock Element
- **SR-MMP**: Mitochondrial Membrane Potential
- **SR-p53**: Tumor Suppressor p53
## Risk Interpretation
| Score | Risk Level |
|-------|------------|
| < 20% | Minimal |
| 20-40% | Low |
| 40-70% | Moderate |
| ≥ 70% | High |
---
Built by [Rasayan Labs](https://rasayan.ai)
""")
    # Footer shown below the tabs.
    gr.Markdown("""
---
**API Endpoints**: `/predict` (POST), `/metadata` (GET), `/health` (GET)
""")
# ASGI application that hosts the JSON API. The Gradio UI is mounted onto it
# at the end of the module.
app = FastAPI()


class PredictRequest(BaseModel):
    # Request body for the prediction endpoint: a non-empty list of SMILES.
    # The upper bound is MAX_BATCH_SIZE, the same limit get_metadata()
    # advertises, so the schema and the published metadata agree instead of
    # using a separate hard-coded 1000.
    smiles: List[str] = Field(..., min_length=1, max_length=MAX_BATCH_SIZE)
class PredictResponse(BaseModel):
    # Shape of the dict returned by predict().
    # SMILES -> {Tox21 task name -> predicted probability}; SMILES the
    # extractor marks invalid carry a constant 0.5 for every task.
    predictions: Dict[str, Dict[str, float]]
    # Identification of the serving model ({"name": ..., "version": ...}).
    model_info: Dict[str, Any]
class MetadataResponse(BaseModel):
    # Shape of the dict returned by get_metadata().
    model_name: str  # human-readable model name
    version: str  # model version string
    max_batch_size: int  # largest SMILES batch clients are told to send
    tox_endpoints: List[str]  # task identifiers, in prediction-column order
    description: str  # free-text model summary
@app.get("/metadata", response_model=MetadataResponse)
def get_metadata():
    """Describe the served model (GET /metadata, advertised in the UI footer).

    Returns:
        A dict matching ``MetadataResponse``:
        - the model name and version;
        - ``MAX_BATCH_SIZE`` as the largest batch clients should send;
        - the task identifiers in prediction-column order;
        - a short description.
    """
    return {
        "model_name": "Rasayan Tox21 SNN Ensemble",
        "version": "1.0.0",
        "max_batch_size": MAX_BATCH_SIZE,
        "tox_endpoints": TASKS,
        "description": "10-fold ensemble of Self-Normalizing Neural Networks trained on Tox21 Challenge data. Features: ECFP6, MACCS, RDKit descriptors, toxicophores, and target similarity."
    }
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Score a batch of SMILES on every Tox21 task (POST /predict).

    Args:
        request: Body whose ``smiles`` field is the list of SMILES to score.

    Returns:
        A dict matching ``PredictResponse``. ``predictions`` maps each input
        SMILES to ``{task: probability}``. SMILES the extractor marks invalid
        get a constant 0.5 for every task, and duplicate SMILES collapse into a
        single entry. ``model_info`` names the serving model.

    Raises:
        HTTPException: 400 if the batch is empty or larger than
            ``MAX_BATCH_SIZE`` (the limit advertised by ``get_metadata``);
            500 if featurization or inference fails.
    """
    smiles_list = request.smiles
    # Enforce the same limit the metadata endpoint publishes, rather than a
    # separate hard-coded 1000 that contradicted it.
    if len(smiles_list) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {MAX_BATCH_SIZE} SMILES per request",
        )
    if not smiles_list:
        raise HTTPException(status_code=400, detail="At least 1 SMILES required")
    try:
        features_dict, valid = extractor.extract_features(smiles_list)
        features = np.concatenate(
            [features_dict[k] for k in FEATURE_KEYS if k in features_dict],
            axis=1,
        )
        # Zero out NaN/±inf so the ensemble never sees non-finite inputs.
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        probs = ensemble.predict(features)
        predictions = {}
        for i, smi in enumerate(smiles_list):
            if valid[i]:
                predictions[smi] = {
                    task: float(probs[i, j]) for j, task in enumerate(TASKS)
                }
            else:
                # Uninformative placeholder score for unparseable structures.
                predictions[smi] = dict.fromkeys(TASKS, 0.5)
        return {
            "predictions": predictions,
            "model_info": {
                "name": "Rasayan Tox21 SNN Ensemble",
                "version": "1.0.0"
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
def health():
    """Liveness check (GET /health); always returns ``{"status": "ok"}``."""
    return {"status": "ok"}
# Serve the Gradio UI from the same ASGI app. Starlette tries routes in the
# order they were added, so anything registered on ``app`` before this mount
# takes precedence over the UI mounted at "/".
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)