nicojara commited on
Commit
b5e5269
·
verified ·
1 Parent(s): ad8fa03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -55
app.py CHANGED
@@ -7,18 +7,16 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
  # load the model from disk
10
- loaded_model = pickle.load(open("glioma_xgb.pkl", 'rb'))
11
 
12
  # Setup SHAP
13
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
14
 
15
  # Create the main function for server
16
- def main_func(Gender, Age_at_diagnosis, IDH1, TP53, ATRX, PTEN, EGFR, CIC, MUC16, PIK3CA, NF1, PIK3R1, FUBP1, RB1, NOTCH1, BCOR, CSMD3, SMARCA4, GRIN2A, IDH2, FAT4, PDGFRA):
17
- new_row = pd.DataFrame.from_dict({'Gender':Gender,
18
- 'Age_at_diagnosis':Age_at_diagnosis,'IDH1':IDH1,'TP53':TP53,
19
- 'ATRX':ATRX, 'PTEN':PTEN,'EGFR':EGFR,'CIC':CIC,
20
- 'MUC16':MUC16,'PIK3CA':PIK3CA,'NF1':NF1,'PIK3R1':PIK3R1, 'FUBP1': FUBP1, 'RB1': RB1, 'NOTCH1': NOTCH1,
21
- 'BCOR': BCOR, 'CSMD3': CSMD3, 'SMARCA4': SMARCA4, 'GRIN2A': GRIN2A, 'IDH2': IDH2, 'FAT4': FAT4, 'PDGFRA': PDGFRA},
22
  orient = 'index').transpose()
23
 
24
  prob = loaded_model.predict_proba(new_row)
@@ -32,72 +30,54 @@ def main_func(Gender, Age_at_diagnosis, IDH1, TP53, ATRX, PTEN, EGFR, CIC, MUC16
32
  local_plot = plt.gcf()
33
  plt.close()
34
 
35
- return {"Chance of Having GBM Tumor": 1-float(prob[0][0]), "Chance of Having LGG Tumor": float(prob[0][0])}, local_plot
 
 
 
36
 
37
  # Create the UI
38
- title = "**Glioma Predictor & Interpreter** 🪐"
39
- description1 = """This app takes info from subjects and predicts the severity of their brain tumor (LGG or GBM). Do not use for medical diagnosis."""
40
 
41
- description2 = """
42
- To use the app, click on one of the examples, or adjust the values of the factors, and click on Analyze. 🤞
43
- """
44
 
45
  with gr.Blocks(title=title) as demo:
46
  gr.Markdown(f"## {title}")
47
  gr.Markdown(description1)
48
- gr.Markdown("""---""")
49
  gr.Markdown(description2)
50
- gr.Markdown("""---""")
51
 
52
  with gr.Row():
53
- Gender = gr.Radio(["Female", "Male"], label="Gender", type = "index")
54
- Age_at_diagnosis = gr.Number(label="Age at Diagnosis")
55
  with gr.Row():
56
- IDH1 = gr.Radio(["No", "Yes"], label="IDH1 Mutation", type="index")
57
- TP53 = gr.Radio(["No", "Yes"], label="TP53 Mutation", type="index")
58
- ATRX = gr.Radio(["No", "Yes"], label="ATRX Mutation", type="index")
59
- with gr.Row():
60
- PTEN = gr.Radio(["No", "Yes"], label="PTEN Mutation", type="index")
61
- EGFR = gr.Radio(["No", "Yes"], label="EGFR Mutation", type="index")
62
- CIC = gr.Radio(["No", "Yes"], label="CIC Mutation", type="index")
63
- with gr.Row():
64
- MUC16 = gr.Radio(["No", "Yes"], label="MUC16 Mutation", type="index")
65
- PIK3CA = gr.Radio(["No", "Yes"], label="PIK3CA Mutation", type="index")
66
- NF1 = gr.Radio(["No", "Yes"], label="NF1 Mutation", type="index")
67
- with gr.Row():
68
- PIK3R1 = gr.Radio(["No", "Yes"], label="PIK3R1 Mutation", type="index")
69
- FUBP1 = gr.Radio(["No", "Yes"], label="FUBP1 Mutation", type="index")
70
- RB1 = gr.Radio(["No", "Yes"], label="RB1 Mutation", type="index")
71
- with gr.Row():
72
- NOTCH1 = gr.Radio(["No", "Yes"], label="NOTCH1 Mutation", type="index")
73
- BCOR = gr.Radio(["No", "Yes"], label="BCOR Mutation", type="index")
74
- CSMD3 = gr.Radio(["No", "Yes"], label="CSMD3 Mutation", type="index")
75
- with gr.Row():
76
- SMARCA4 = gr.Radio(["No", "Yes"], label="SMAECA4 Mutation", type="index")
77
- GRIN2A = gr.Radio(["No", "Yes"], label="GRIN2A Mutation", type="index")
78
- IDH2 = gr.Radio(["No", "Yes"], label="IDH2 Mutation", type="index")
79
- FAT4 = gr.Radio(["No", "Yes"], label="FAT4 Mutation", type="index")
80
- PDGFRA = gr.Radio(["No", "Yes"], label="PDGFRA Mutation", type="index")
81
-
82
-
83
 
84
-
85
-
86
-
87
-
88
  submit_btn = gr.Button("Analyze")
89
 
90
  with gr.Column(visible=True) as output_col:
91
- label = gr.Label(label = "Predicted Label")
92
- local_plot = gr.Plot(label = 'Grade:')
93
 
94
  submit_btn.click(
95
  main_func,
96
- [Gender, Age_at_diagnosis, IDH1, TP53, ATRX, PTEN, EGFR, CIC, MUC16, PIK3CA, NF1, PIK3R1, FUBP1, RB1, NOTCH1, BCOR, CSMD3, SMARCA4, GRIN2A, IDH2, FAT4, PDGFRA],
97
- [label,local_plot], api_name="Glioma_Predictor"
 
98
  )
99
-
100
- gr.Markdown("### Click on any of the examples below to see how it works:")
101
- gr.Examples([["Male",24,"Yes","No","Yes","Yes","Yes","No","Yes","Yes","Yes","Yes","Yes","No","No","No","No","Yes","No","Yes","No","Yes"], ["Male",70,"No","No","No","No","No","No","No","No","No","Yes","No","Yes","No","No","No","No","No","No","No", "No"]], [Gender, Age_at_diagnosis, IDH1, TP53, ATRX, PTEN, EGFR, CIC, MUC16, PIK3CA, NF1, PIK3R1, FUBP1, RB1, NOTCH1, BCOR, CSMD3, SMARCA4, GRIN2A, IDH2, FAT4, PDGFRA], [label,local_plot], main_func, cache_examples=True)
 
 
 
 
 
 
 
102
 
103
  demo.launch()
 
7
  import matplotlib.pyplot as plt
8
 
9
  # load the model from disk
10
+ loaded_model = pickle.load(open("salar_xgb_team.pkl", 'rb'))
11
 
12
  # Setup SHAP
13
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
14
 
15
  # Create the main function for server
16
+ def main_func(age, education-num, sex, capital-gain, capital-loss, hours-per-week, salary-class):
17
+ new_row = pd.DataFrame.from_dict({'age':age,
18
+ 'education-num':education-num,'sex':sex,'capital-gain':capital-gain,
19
+ 'capital-loss':capital-loss, 'hours-per-week':hours-per-week,'salary-class':salary-class},
 
 
20
  orient = 'index').transpose()
21
 
22
  prob = loaded_model.predict_proba(new_row)
 
30
  local_plot = plt.gcf()
31
  plt.close()
32
 
33
+ return {
34
+ "Chance of Earning > $50K": float(prob[0][1]),
35
+ "Chance of Earning ≤ $50K": float(prob[0][0])
36
+ }, local_plot
37
 
38
  # Create the UI
39
+ title = "**Household Income Predictor** 💰"
40
+ description1 = """This app uses your input to predict whether a household earns more or less than $50K per year"""
41
 
42
+ description2 = """Adjust the values below and click 'Analyze' to see the prediction and explanation."""
 
 
43
 
44
  with gr.Blocks(title=title) as demo:
45
  gr.Markdown(f"## {title}")
46
  gr.Markdown(description1)
47
+ gr.Markdown("---")
48
  gr.Markdown(description2)
49
+ gr.Markdown("---")
50
 
51
  with gr.Row():
52
+ age = gr.Number(label="Age", value=35)
53
+ education_num = gr.Number(label="Education Level (numeric)", value=10)
54
  with gr.Row():
55
+ sex = gr.Radio(["Male", "Female"], label="Sex", type="index") # Male = 0, Female = 1
56
+ capital_gain = gr.Number(label="Capital Gain", value=0)
57
+ capital_loss = gr.Number(label="Capital Loss", value=0)
58
+ hours_per_week = gr.Number(label="Hours per Week", value=40)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
 
 
 
60
  submit_btn = gr.Button("Analyze")
61
 
62
  with gr.Column(visible=True) as output_col:
63
+ label = gr.Label(label="Predicted Probabilities")
64
+ local_plot = gr.Plot(label="Top SHAP Features")
65
 
66
  submit_btn.click(
67
  main_func,
68
+ [age, education_num, sex, capital_gain, capital_loss, hours_per_week],
69
+ [label, local_plot],
70
+ api_name="Salary_Predictor"
71
  )
72
+
73
+ gr.Markdown("### Examples:")
74
+ gr.Examples([
75
+ [28, 12, 0, 0, 0, 40], # Male, younger, more educated
76
+ [50, 9, 1, 0, 0, 30] # Female, mid-education, fewer hours
77
+ ],
78
+ inputs=[age, education_num, sex, capital_gain, capital_loss, hours_per_week],
79
+ outputs=[label, local_plot],
80
+ fn=main_func,
81
+ cache_examples=True)
82
 
83
  demo.launch()