File size: 2,490 Bytes
d04b4c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import shap
import gradio as gr
import numpy as np
import shap
import joblib
import pandas as pd
import matplotlib.pyplot as plt

# Trained XGBoost model, loaded once at import time and shared by every
# request. joblib.load unpickles the file, so only ship a trusted artifact.
model = joblib.load(filename='gdsc_xgboost_model.pkl')

_explainer = None


def _get_explainer():
    """Return a SHAP explainer for the module-level ``model``, built on first use.

    The explainer depends only on the loaded model, so it is created once and
    reused instead of being rebuilt on every prediction request.
    """
    global _explainer
    if _explainer is None:
        _explainer = shap.Explainer(model)
    return _explainer


def predict_ic50(AUC, Z_SCORE, DRUG_ID, TARGET, TARGET_PATHWAY, Growth_Properties_Suspension):
    """Predict LN_IC50 for one input row and render a SHAP waterfall plot.

    Args:
        AUC: Area under the dose-response curve.
        Z_SCORE: Z-score for the dose-response curve.
        DRUG_ID: Numeric drug identifier.
        TARGET: Gene or protein targeted by the drug (one-hot encoded).
        TARGET_PATHWAY: Biological pathway affected (one-hot encoded).
        Growth_Properties_Suspension: Whether the cells grow in suspension.

    Returns:
        A ``(message, image_path)`` tuple: the formatted prediction text and
        the path of the PNG file holding the SHAP waterfall plot.

    Raises:
        ValueError: If the loaded model does not expose feature names, so the
            one-hot encoded input cannot be aligned with its training columns.
    """
    # Create a single-row DataFrame; keys must match the training column names.
    input_data = pd.DataFrame([{
        'AUC': AUC,
        'Z_SCORE': Z_SCORE,
        'DRUG_ID': DRUG_ID,
        'TARGET': TARGET,
        'TARGET_PATHWAY': TARGET_PATHWAY,
        'Growth Properties_Suspension': Growth_Properties_Suspension
    }])

    # One-hot encode the text columns (e.g. TARGET -> TARGET_<value>).
    input_data = pd.get_dummies(input_data)

    model_features = model.get_booster().feature_names
    if model_features is None:
        raise ValueError("Loaded model has no feature names; cannot align inputs.")

    # Align with the training columns in one step: features absent from this
    # row become 0, and columns the model never saw (e.g. an unknown TARGET
    # value) are dropped. This replaces inserting missing columns one by one.
    input_data = input_data.reindex(columns=model_features, fill_value=0)

    # Predict IC50 (model output is reported as LN_IC50).
    ic50_pred = model.predict(input_data)[0]

    # SHAP explanation for the single row.
    shap_values = _get_explainer()(input_data)

    # NOTE(review): this fixed output path is shared by all requests; if the
    # app serves requests concurrently, consider a per-request temp file.
    plot_path = "shap_plot.png"
    fig = plt.figure(figsize=(10, 6))
    try:
        # show=False is required when saving: with the default show=True the
        # waterfall plot calls plt.show(), which may flush/close the figure
        # before savefig runs and produce an empty image.
        shap.plots.waterfall(shap_values[0], max_display=10, show=False)
        plt.title("SHAP Explanation for Prediction")
        plt.savefig(plot_path)
    finally:
        # Release the figure even if plotting fails, so repeated requests do
        # not accumulate open matplotlib figures.
        plt.close(fig)

    return f"Predicted LN_IC50: {ic50_pred:.3f}", plot_path

# Input widgets, listed in the same order as predict_ic50's parameters.
inputs = [
    gr.Number(label="AUC (0.5 - 1.5)", value=0.85, info="Area Under Curve - Typically 0.5 to 1.5"),
    gr.Number(label="Z_SCORE (-2 to 2)", value=0.45, info="Z-Score for dose-response curve"),
    gr.Number(label="DRUG_ID (Numeric Code)", value=1003, info="Unique identifier for the drug"),
    gr.Textbox(label="TARGET", value="MTORC1", placeholder="e.g., MTORC1", info="Gene or protein targeted by the drug"),
    gr.Textbox(label="TARGET_PATHWAY", value="PI3K/MTOR signaling", placeholder="e.g., PI3K/MTOR signaling", info="Biological pathway affected"),
    gr.Checkbox(label="Growth Properties - Suspension", value=False, info="Check if cells grow in suspension")
]

# Output components, matching predict_ic50's (message, image path) return tuple.
outputs = [
    gr.Textbox(label="Predicted LN_IC50"),
    gr.Image(label="SHAP Explanation")
]

# Keep the Interface in a module-level variable named `demo`: Gradio's reload
# mode (`gradio app.py`) looks for that name by default, and the object can be
# reused (mounted, relaunched with options) without rebuilding it.
demo = gr.Interface(
    fn=predict_ic50,
    inputs=inputs,
    outputs=outputs,
    title="GDSC Drug Sensitivity Predictor",
    description="Predict LN_IC50 for cancer drug response and visualize feature impact using SHAP. Please follow the input guidelines for accurate predictions.",
    theme="default"
)
demo.launch()