"""Gradio UI and FastAPI service for the Rasayan Tox21 toxicity classifier.

The module loads a feature extractor and a model ensemble at import time,
builds a three-tab Gradio interface (single molecule, batch, about) and
exposes ``/predict``, ``/metadata`` and ``/health`` HTTP endpoints. The
Gradio app is mounted on the FastAPI app at ``/``.
"""

import sys
from pathlib import Path
from typing import Any, Dict, List

import gradio as gr
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Make the project root importable so the local ``src`` package resolves
# regardless of the working directory the server is started from.
ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))

from src import EnhancedFeatureExtractor, Tox21Ensemble  # noqa: E402


# Column order of the model output: probs[:, j] is indexed with TASKS[j].
TASKS = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Human-readable label shown under each task in the single-molecule report.
TASK_DESCRIPTIONS = {
    "NR-AR": "Androgen Receptor",
    "NR-AR-LBD": "Androgen Receptor LBD",
    "NR-AhR": "Aryl Hydrocarbon Receptor",
    "NR-Aromatase": "Aromatase (CYP19A1)",
    "NR-ER": "Estrogen Receptor",
    "NR-ER-LBD": "Estrogen Receptor LBD",
    "NR-PPAR-gamma": "PPARγ",
    "SR-ARE": "Antioxidant Response",
    "SR-ATAD5": "DNA Damage (ATAD5)",
    "SR-HSE": "Heat Shock Response",
    "SR-MMP": "Mitochondrial Toxicity",
    "SR-p53": "Genotoxicity (p53)",
}

# Feature groups are concatenated in exactly this order; groups missing from
# the extractor output are skipped.
FEATURE_KEYS = [
    "ecfps",
    "maccs",
    "rdkit_descrs",
    "tox",
    "rdkit_filters",
    "similarity",
    "max_similarity",
    "db_similarity",
]

MODEL_NAME = "Rasayan Tox21 SNN Ensemble"
MODEL_VERSION = "1.0.0"

# Batch size advertised by /metadata.
# NOTE(review): /predict accepts up to MAX_API_SMILES (1000) entries, which is
# larger than this advertised value. Both limits are kept as they were so
# existing clients are unaffected — confirm which one is intended.
MAX_BATCH_SIZE = 256
MAX_API_SMILES = 1000

# Limits and presentation settings for the Gradio UI.
MAX_UI_BATCH = 100
SMILES_DISPLAY_LEN = 50
BAR_WIDTH = 20

# Score reported by /predict for every task of a SMILES flagged invalid.
INVALID_SMILES_SCORE = 0.5


print("Loading model...")
extractor = EnhancedFeatureExtractor(
    toxicophores_path=ROOT / "data" / "toxicophores_validated.json",
    db_ligands_path=ROOT / "data" / "target_ligands_validated.json",
)
ensemble = Tox21Ensemble(ROOT / "checkpoints" / "ensemble.pt")
print("Model loaded successfully!")


def _build_feature_matrix(features_dict: Dict[str, Any]) -> np.ndarray:
    """Concatenate the extracted feature groups into one model input matrix.

    Groups are joined along axis 1 in ``FEATURE_KEYS`` order, skipping keys
    absent from ``features_dict``. NaN and +/-inf entries are replaced by 0.0.

    Raises:
        ValueError: propagated from ``np.concatenate`` when none of the
            ``FEATURE_KEYS`` is present.
    """
    blocks = [features_dict[key] for key in FEATURE_KEYS if key in features_dict]
    features = np.concatenate(blocks, axis=1)
    return np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)


def _display_smiles(smi: str) -> str:
    """Return ``smi`` truncated to ``SMILES_DISPLAY_LEN`` chars plus ``...``."""
    if len(smi) > SMILES_DISPLAY_LEN:
        return smi[:SMILES_DISPLAY_LEN] + "..."
    return smi


def _risk_label(score: float) -> str:
    """Map a score to the fixed-width risk tag used in the text report."""
    if score >= 0.7:
        return "HIGH"
    if score >= 0.4:
        return "MED "
    if score >= 0.2:
        return "LOW "
    return "MIN "


def predict_toxicity(smiles_input: str) -> tuple:
    """Predict all tasks for a newline-separated list of SMILES (batch tab).

    Args:
        smiles_input: Text with one SMILES per line; blank lines are ignored.

    Returns:
        ``(dataframe, status)``. ``dataframe`` has a ``SMILES`` column plus
        one column per task, holding a percentage string or ``"Invalid"``
        for entries the extractor flags as invalid. On empty input, too many
        molecules, or any processing error, the first element is ``None``
        and ``status`` carries the message.
    """
    # Guard against None as well as blank text so the handler never raises.
    if not smiles_input or not smiles_input.strip():
        return None, "Please enter at least one SMILES"

    lines = [s.strip() for s in smiles_input.strip().split("\n") if s.strip()]
    if len(lines) > MAX_UI_BATCH:
        return None, f"Maximum {MAX_UI_BATCH} molecules per request"

    try:
        features_dict, valid = extractor.extract_features(lines)
        probs = ensemble.predict(_build_feature_matrix(features_dict))

        results = []
        for i, smi in enumerate(lines):
            row = {"SMILES": _display_smiles(smi)}
            if valid[i]:
                for j, task in enumerate(TASKS):
                    row[task] = f"{float(probs[i, j]):.1%}"
            else:
                row.update(dict.fromkeys(TASKS, "Invalid"))
            results.append(row)

        # Imported lazily so pandas is only required when the batch tab is used.
        import pandas as pd

        return pd.DataFrame(results), f"Processed {len(lines)} molecule(s)"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"Error: {str(e)}"


def predict_single(smiles: str) -> str:
    """Build a text report of all task scores for one SMILES (single tab).

    Args:
        smiles: A single SMILES string; surrounding whitespace is ignored.

    Returns:
        A multi-line report with tasks sorted by descending score, each with
        a bar, percentage, risk tag and description; or a short message when
        the input is empty, flagged invalid, or processing fails.
    """
    smiles = smiles.strip() if smiles else ""
    if not smiles:
        return "Enter a SMILES string"

    try:
        # Pass the stripped string: the emptiness check and the extractor
        # must see the same value.
        features_dict, valid = extractor.extract_features([smiles])
        if not valid[0]:
            return "Invalid SMILES structure"

        probs = ensemble.predict(_build_feature_matrix(features_dict))

        rule = "═" * 45
        lines = [rule, " TOXICITY PREDICTION RESULTS", rule]

        sorted_results = sorted(
            [(task, float(probs[0, j])) for j, task in enumerate(TASKS)],
            key=lambda x: -x[1],
        )

        for task, score in sorted_results:
            # Clamp so a score outside [0, 1] cannot distort the fixed-width bar.
            bar_len = max(0, min(BAR_WIDTH, int(score * BAR_WIDTH)))
            bar = "█" * bar_len + "░" * (BAR_WIDTH - bar_len)
            risk = _risk_label(score)
            lines.append(f"{task:15} {bar} {score:5.1%} [{risk}]")
            lines.append(f" └─ {TASK_DESCRIPTIONS[task]}")

        lines.append(rule)
        return "\n".join(lines)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Error: {str(e)}"


# Example inputs offered under the single-molecule textbox.
EXAMPLES = [
    ["CCO"],
    ["CC(=O)Nc1ccc(O)cc1"],
    ["c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45"],
    ["CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C"],
    ["CC12CCC3c4ccc(O)cc4CCC3C1CCC2O"],
]

# Markdown is kept at column 0 so no line is indented far enough to be
# rendered as a code block.
HEADER_MD = """
# ☠️ Rasayan Tox21 Classifier

Predict molecular toxicity across **12 Tox21 endpoints** using a Self-Normalizing Neural Network ensemble.

| Model | Features | Training |
|-------|----------|----------|
| 10-fold SNN Ensemble | 11,377 molecular descriptors | 40-fold CV, AUC: 0.882 |
"""

ABOUT_MD = """
## Model Architecture

**Self-Normalizing Neural Networks (SNNs)** with SELU activation and AlphaDropout.

| Component | Details |
|-----------|---------|
| Hidden Layers | 8 × 768 units |
| Activation | SELU |
| Dropout | AlphaDropout (0.1) |
| Ensemble | Top-10 from 40-fold CV |
| Parameters | ~19M total |

## Molecular Features (11,377 total)

| Feature | Dimensions | Description |
|---------|------------|-------------|
| ECFP6 | 8,192 | Morgan fingerprints (radius 3) |
| MACCS | 167 | Structural keys |
| RDKit | 208 | Physicochemical descriptors |
| Toxicophores | 1,868 | Toxicity structural alerts |
| Filters | 815 | PAINS, BRENK, NIH, ZINC |
| Similarity | 127 | Target ligand similarity |

## Tox21 Endpoints

### Nuclear Receptor Panel
- **NR-AR**: Androgen Receptor
- **NR-AR-LBD**: AR Ligand Binding Domain
- **NR-AhR**: Aryl Hydrocarbon Receptor
- **NR-Aromatase**: CYP19A1 Enzyme
- **NR-ER**: Estrogen Receptor
- **NR-ER-LBD**: ER Ligand Binding Domain
- **NR-PPAR-gamma**: Peroxisome Proliferator-Activated Receptor

### Stress Response Panel
- **SR-ARE**: Antioxidant Response Element
- **SR-ATAD5**: DNA Damage Response
- **SR-HSE**: Heat Shock Element
- **SR-MMP**: Mitochondrial Membrane Potential
- **SR-p53**: Tumor Suppressor p53

## Risk Interpretation

| Score | Risk Level |
|-------|------------|
| < 20% | Minimal |
| 20-40% | Low |
| 40-70% | Moderate |
| ≥ 70% | High |

---

Built by [Rasayan Labs](https://rasayan.ai)
"""

FOOTER_MD = """
---
**API Endpoints**: `/predict` (POST), `/metadata` (GET), `/health` (GET)
"""


with gr.Blocks(
    title="Rasayan Tox21 Classifier",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(HEADER_MD)

    with gr.Tabs():
        with gr.TabItem("Single Molecule"):
            with gr.Row():
                with gr.Column(scale=1):
                    single_input = gr.Textbox(
                        label="SMILES",
                        placeholder="Enter SMILES (e.g., CCO for ethanol)",
                        lines=1,
                    )
                    single_btn = gr.Button("Predict", variant="primary")
                    gr.Examples(
                        examples=EXAMPLES,
                        inputs=single_input,
                        label="Example Molecules",
                    )
                with gr.Column(scale=2):
                    single_output = gr.Textbox(
                        label="Toxicity Profile",
                        lines=30,
                    )
            single_btn.click(
                fn=predict_single,
                inputs=single_input,
                outputs=single_output,
            )

        with gr.TabItem("Batch Processing"):
            gr.Markdown("Enter multiple SMILES (one per line, max 100)")
            batch_input = gr.Textbox(
                label="SMILES List",
                placeholder="CCO\nCC(=O)Nc1ccc(O)cc1\nc1ccccc1",
                lines=5,
            )
            batch_btn = gr.Button("Process Batch", variant="primary")
            batch_status = gr.Textbox(label="Status", lines=1)
            batch_output = gr.Dataframe(
                label="Results",
                wrap=True,
            )
            batch_btn.click(
                fn=predict_toxicity,
                inputs=batch_input,
                outputs=[batch_output, batch_status],
            )

        with gr.TabItem("About"):
            gr.Markdown(ABOUT_MD)

    gr.Markdown(FOOTER_MD)


app = FastAPI()


class PredictRequest(BaseModel):
    """Request body for ``/predict``: a non-empty list of SMILES strings."""

    smiles: List[str] = Field(..., min_length=1, max_length=MAX_API_SMILES)


class PredictResponse(BaseModel):
    """Response body for ``/predict``."""

    # Maps each input SMILES to a {task name: score} dict.
    predictions: Dict[str, Dict[str, float]]
    model_info: Dict[str, Any]


class MetadataResponse(BaseModel):
    """Response body for ``/metadata``."""

    model_name: str
    version: str
    max_batch_size: int
    tox_endpoints: List[str]
    description: str


@app.get("/metadata", response_model=MetadataResponse)
def get_metadata():
    """Return static model name, version, batch size, tasks and description."""
    return {
        "model_name": MODEL_NAME,
        "version": MODEL_VERSION,
        "max_batch_size": MAX_BATCH_SIZE,
        "tox_endpoints": TASKS,
        "description": "10-fold ensemble of Self-Normalizing Neural Networks trained on Tox21 Challenge data. Features: ECFP6, MACCS, RDKit descriptors, toxicophores, and target similarity.",
    }


@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Predict all tasks for every SMILES in the request.

    Returns a dict with ``predictions`` (SMILES -> {task: score}) and
    ``model_info``. A SMILES the extractor flags as invalid gets
    ``INVALID_SMILES_SCORE`` for every task instead of an error. Because the
    result is keyed by SMILES, repeated inputs yield a single entry.

    Raises:
        HTTPException: 400 when the list is empty or longer than
            ``MAX_API_SMILES``; 500 with the error text when feature
            extraction or inference fails.
    """
    smiles_list = request.smiles

    # The Field constraints on PredictRequest.smiles declare the same bounds,
    # so request validation is expected to reject such input first; these
    # explicit checks are kept as a fallback.
    if len(smiles_list) > MAX_API_SMILES:
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {MAX_API_SMILES} SMILES per request",
        )
    if not smiles_list:
        raise HTTPException(status_code=400, detail="At least 1 SMILES required")

    try:
        features_dict, valid = extractor.extract_features(smiles_list)
        probs = ensemble.predict(_build_feature_matrix(features_dict))

        predictions = {}
        for i, smi in enumerate(smiles_list):
            if valid[i]:
                predictions[smi] = {
                    task: float(probs[i, j]) for j, task in enumerate(TASKS)
                }
            else:
                predictions[smi] = dict.fromkeys(TASKS, INVALID_SMILES_SCORE)

        return {
            "predictions": predictions,
            "model_info": {
                "name": MODEL_NAME,
                "version": MODEL_VERSION,
            },
        }
    except Exception as e:
        # Service boundary: convert any failure to a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
def health():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


# Serve the Gradio UI at the root path of the same FastAPI application.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)