Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shap
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
import shap
|
| 5 |
+
import joblib
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
+
# Load the serialized model once, at import time; predict_ic50 reads this
# module-level object on every call instead of reloading the file.
# NOTE(review): the file name suggests an XGBoost model; predict_ic50 relies on
# it exposing get_booster() and predict() -- confirm against the training code.
# joblib.load unpickles the file, so only ship a model file from a trusted source.
model = joblib.load('gdsc_xgboost_model.pkl')
|
| 10 |
+
|
| 11 |
+
def predict_ic50(AUC, Z_SCORE, DRUG_ID, TARGET, TARGET_PATHWAY, Growth_Properties_Suspension):
    """Predict LN_IC50 for one sample and render a SHAP waterfall plot for it.

    The six arguments are packed into a single-row DataFrame, one-hot encoded
    with ``pd.get_dummies`` and aligned to the feature names stored in the
    module-level ``model`` before prediction.

    Args:
        AUC: Value for the ``AUC`` feature.
        Z_SCORE: Value for the ``Z_SCORE`` feature.
        DRUG_ID: Value for the ``DRUG_ID`` feature.
        TARGET: Drug target; string values are one-hot encoded into a
            ``TARGET_<value>`` column.
        TARGET_PATHWAY: Pathway; string values are one-hot encoded into a
            ``TARGET_PATHWAY_<value>`` column.
        Growth_Properties_Suspension: Value for the
            ``Growth Properties_Suspension`` feature.

    Returns:
        A ``(message, image_path)`` tuple: the formatted prediction text and
        the path of the saved SHAP plot (``"shap_plot.png"``).
    """
    # Single-row frame keyed by the column names used for the model features.
    input_data = pd.DataFrame([{
        'AUC': AUC,
        'Z_SCORE': Z_SCORE,
        'DRUG_ID': DRUG_ID,
        'TARGET': TARGET,
        'TARGET_PATHWAY': TARGET_PATHWAY,
        'Growth Properties_Suspension': Growth_Properties_Suspension,
    }])

    # One-hot encode the categorical (string) columns.
    input_data = pd.get_dummies(input_data)

    # Align with the model's feature list in one step: features missing from
    # the input are added as 0, columns the model does not know are dropped,
    # and the column order matches the model. A TARGET / TARGET_PATHWAY value
    # that has no matching model feature therefore ends up as all zeros.
    model_features = model.get_booster().feature_names
    input_data = input_data.reindex(columns=model_features, fill_value=0)

    # Predict IC50 for the single row.
    ic50_pred = model.predict(input_data)[0]

    # SHAP explanation for the same row.
    explainer = shap.Explainer(model)
    shap_values = explainer(input_data)

    # Plot the explanation. show=False keeps SHAP from calling plt.show()
    # before the title is set and the figure is written to disk.
    fig = plt.figure(figsize=(10, 6))
    try:
        shap.plots.waterfall(shap_values[0], max_display=10, show=False)
        plt.title("SHAP Explanation for Prediction")
        # Tight bounding box so long feature labels are not clipped.
        plt.savefig("shap_plot.png", bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)

    return f"Predicted LN_IC50: {ic50_pred:.3f}", "shap_plot.png"
|
| 48 |
+
|
| 49 |
+
# Input widgets, in the same order as the parameters of predict_ic50.
inputs = [
    gr.Number(
        label="AUC (0.5 - 1.5)",
        value=0.85,
        info="Area Under Curve - Typically 0.5 to 1.5",
    ),
    gr.Number(
        label="Z_SCORE (-2 to 2)",
        value=0.45,
        info="Z-Score for dose-response curve",
    ),
    gr.Number(
        label="DRUG_ID (Numeric Code)",
        value=1003,
        info="Unique identifier for the drug",
    ),
    gr.Textbox(
        label="TARGET",
        value="MTORC1",
        placeholder="e.g., MTORC1",
        info="Gene or protein targeted by the drug",
    ),
    gr.Textbox(
        label="TARGET_PATHWAY",
        value="PI3K/MTOR signaling",
        placeholder="e.g., PI3K/MTOR signaling",
        info="Biological pathway affected",
    ),
    gr.Checkbox(
        label="Growth Properties - Suspension",
        value=False,
        info="Check if cells grow in suspension",
    ),
]

# Output widgets, matching the (text, image path) pair predict_ic50 returns.
outputs = [
    gr.Textbox(label="Predicted LN_IC50"),
    gr.Image(label="SHAP Explanation"),
]

# Build the interface, then start serving it.
demo = gr.Interface(
    fn=predict_ic50,
    inputs=inputs,
    outputs=outputs,
    title="GDSC Drug Sensitivity Predictor",
    description="Predict LN_IC50 for cancer drug response and visualize feature impact using SHAP. Please follow the input guidelines for accurate predictions.",
    theme="default",
)
demo.launch()
|