arittrabag's picture
Create app.py
d04b4c3 verified
import shap
import gradio as gr
import numpy as np
import shap
import joblib
import pandas as pd
import matplotlib.pyplot as plt
# Deserialize the trained model once at import time so every request reuses it.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only deploy a model artifact from a trusted source.
_MODEL_FILE = 'gdsc_xgboost_model.pkl'
model = joblib.load(_MODEL_FILE)
_SHAP_EXPLAINER_CACHE = {}


def _get_explainer():
    """Return the SHAP explainer for the module-level ``model``, building it on first use."""
    # The model never changes after import, so rebuilding the explainer on
    # every request is wasted work.
    if "explainer" not in _SHAP_EXPLAINER_CACHE:
        _SHAP_EXPLAINER_CACHE["explainer"] = shap.Explainer(model)
    return _SHAP_EXPLAINER_CACHE["explainer"]


def _encode_features(record, feature_names):
    """One-hot encode a single raw record and align it to the model's feature columns."""
    if feature_names is None:
        raise ValueError(
            "The loaded model does not expose feature names, so the input "
            "cannot be aligned to its training columns."
        )
    encoded = pd.get_dummies(pd.DataFrame([record]))
    # reindex adds every feature the model expects but this record did not
    # produce (filled with 0), drops dummy columns the model never saw, and
    # puts the columns in the model's order — all in one step.
    return encoded.reindex(columns=list(feature_names), fill_value=0)


def predict_ic50(AUC, Z_SCORE, DRUG_ID, TARGET, TARGET_PATHWAY, Growth_Properties_Suspension):
    """Predict LN_IC50 for one input record and render a SHAP waterfall plot.

    Args:
        AUC: Value for the ``AUC`` feature.
        Z_SCORE: Value for the ``Z_SCORE`` feature.
        DRUG_ID: Value for the ``DRUG_ID`` feature.
        TARGET: Value for the ``TARGET`` feature (one-hot encoded if non-numeric).
        TARGET_PATHWAY: Value for the ``TARGET_PATHWAY`` feature (one-hot
            encoded if non-numeric).
        Growth_Properties_Suspension: Value for the
            ``Growth Properties_Suspension`` feature.

    Returns:
        A 2-tuple ``(text, image_path)``: the formatted prediction string and
        the path of a PNG file containing the SHAP waterfall plot.

    Raises:
        ValueError: If the loaded model does not expose feature names.
    """
    import os
    import tempfile

    record = {
        'AUC': AUC,
        'Z_SCORE': Z_SCORE,
        'DRUG_ID': DRUG_ID,
        'TARGET': TARGET,
        'TARGET_PATHWAY': TARGET_PATHWAY,
        'Growth Properties_Suspension': Growth_Properties_Suspension,
    }
    input_data = _encode_features(record, model.get_booster().feature_names)

    # Predict IC50 for the single row.
    ic50_pred = model.predict(input_data)[0]

    # SHAP explanation for the same row.
    shap_values = _get_explainer()(input_data)

    # Each call writes its own file: a fixed filename would let concurrent
    # requests overwrite one another's plot before the UI reads it back.
    # The file is intentionally left on disk for the caller to read.
    fd, plot_path = tempfile.mkstemp(prefix="shap_plot_", suffix=".png")
    os.close(fd)

    plt.figure(figsize=(10, 6))
    try:
        # show=False keeps SHAP from calling plt.show() itself; otherwise the
        # figure may already be displayed/cleared before the title is added
        # and the image is saved.
        shap.plots.waterfall(shap_values[0], max_display=10, show=False)
        plt.title("SHAP Explanation for Prediction")
        # bbox_inches="tight" keeps long feature labels from being clipped.
        plt.savefig(plot_path, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

    return f"Predicted LN_IC50: {ic50_pred:.3f}", plot_path
# Input widgets, in the same order as predict_ic50's parameters.
inputs = [
    gr.Number(
        label="AUC (0.5 - 1.5)",
        value=0.85,
        info="Area Under Curve - Typically 0.5 to 1.5",
    ),
    gr.Number(
        label="Z_SCORE (-2 to 2)",
        value=0.45,
        info="Z-Score for dose-response curve",
    ),
    gr.Number(
        label="DRUG_ID (Numeric Code)",
        value=1003,
        info="Unique identifier for the drug",
    ),
    gr.Textbox(
        label="TARGET",
        value="MTORC1",
        placeholder="e.g., MTORC1",
        info="Gene or protein targeted by the drug",
    ),
    gr.Textbox(
        label="TARGET_PATHWAY",
        value="PI3K/MTOR signaling",
        placeholder="e.g., PI3K/MTOR signaling",
        info="Biological pathway affected",
    ),
    gr.Checkbox(
        label="Growth Properties - Suspension",
        value=False,
        info="Check if cells grow in suspension",
    ),
]

# Output widgets: the prediction text followed by the SHAP plot image.
outputs = [
    gr.Textbox(label="Predicted LN_IC50"),
    gr.Image(label="SHAP Explanation"),
]

# Build the interface first, then start serving it.
demo = gr.Interface(
    fn=predict_ic50,
    inputs=inputs,
    outputs=outputs,
    title="GDSC Drug Sensitivity Predictor",
    description="Predict LN_IC50 for cancer drug response and visualize feature impact using SHAP. Please follow the input guidelines for accurate predictions.",
    theme="default",
)
demo.launch()