} commited on
Commit
08461c2
·
1 Parent(s): fa9fe39

Update with new results

Browse files
app.py CHANGED
@@ -6,7 +6,7 @@ from utils.prediction import predict_drug
6
 
7
  # Loading the model
8
  filename = "./model/drug_pipeline.sav"
9
- pipe = pickle.load(open(filename, 'rb'))
10
 
11
 
12
  # Web interface section
@@ -27,15 +27,22 @@ with st.sidebar.expander("Single Prediction"):
27
  cholesterol_list = np.array(["HIGH", "NORMAL"])
28
  cholesterol = st.radio("Select your Cholesterol", cholesterol_list)
29
 
30
- Na_to_k = st.number_input(label="NA_to_K", min_value=6.2, step=0.1, max_value=38.2)
 
 
31
  submit_button = st.form_submit_button(label="Predict")
32
 
33
  if submit_button:
34
- prediction = predict_drug(age, gender[0], blood_pressure, cholesterol, Na_to_k, pipe)
 
 
35
 
36
  if prediction is None:
37
  st.error("An error occurred while getting the prediction!")
38
 
39
  message = f"The Drug is {prediction}!"
40
- message_color = 'red' if prediction == 1 else 'green'
41
- st.markdown(f"<h3 style='text-align: left;color:{message_color}'> {(message)} </h3>", unsafe_allow_html=True)
 
 
 
 
6
 
7
  # Loading the model
8
  filename = "./model/drug_pipeline.sav"
9
+ pipe = pickle.load(open(filename, "rb"))
10
 
11
 
12
  # Web interface section
 
27
  cholesterol_list = np.array(["HIGH", "NORMAL"])
28
  cholesterol = st.radio("Select your Cholesterol", cholesterol_list)
29
 
30
+ Na_to_k = st.number_input(
31
+ label="NA_to_K", min_value=6.2, step=0.1, max_value=38.2
32
+ )
33
  submit_button = st.form_submit_button(label="Predict")
34
 
35
  if submit_button:
36
+ prediction = predict_drug(
37
+ age, gender[0], blood_pressure, cholesterol, Na_to_k, pipe
38
+ )
39
 
40
  if prediction is None:
41
  st.error("An error occurred while getting the prediction!")
42
 
43
  message = f"The Drug is {prediction}!"
44
+ message_color = "red" if prediction == 1 else "green"
45
+ st.markdown(
46
+ f"<h3 style='text-align: left;color:{message_color}'> {(message)} </h3>",
47
+ unsafe_allow_html=True,
48
+ )
model/drug_pipeline.sav CHANGED
Binary files a/model/drug_pipeline.sav and b/model/drug_pipeline.sav differ
 
results/metrics.txt CHANGED
@@ -1,2 +1,2 @@
1
 
2
- Accuracy=0.95, F1_score = 0.89
 
1
 
2
+ Accuracy=0.97, F1_score = 0.85
results/model_result.png CHANGED
train.py CHANGED
@@ -17,7 +17,9 @@ from sklearn.model_selection import train_test_split
17
  X = drug_df.drop("Drug", axis=1).values
18
  y = drug_df.Drug.values
19
 
20
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=125)
 
 
21
 
22
  # Pipeline
23
  cat_col = [1, 2, 3]
@@ -32,7 +34,7 @@ transform = ColumnTransformer(
32
  )
33
 
34
  pipe = Pipeline(
35
- steps = [
36
  ("preprocessing", transform),
37
  ("model", RandomForestClassifier(n_estimators=10, random_state=125)),
38
  ]
@@ -46,7 +48,7 @@ predictions = pipe.predict(X_test)
46
  accuracy = accuracy_score(y_test, predictions)
47
  f1 = f1_score(y_test, predictions, average="macro")
48
 
49
- print("Accuracy: ", str(round(accuracy, 2)*100)+"%","F1: ", round(f1, 2))
50
 
51
  # Confusion matrix
52
  import matplotlib.pyplot as plt
@@ -64,8 +66,9 @@ with open("./results/metrics.txt", "w") as outfile:
64
 
65
  # Save the model
66
  import pickle
 
67
  # save the model to disk
68
- filename = './model/drug_pipeline.sav'
69
- pickle.dump(pipe, open(filename, 'wb'))
70
 
71
- #sio.dump(pipe, "./model/drug_pipeline.skops")
 
17
  X = drug_df.drop("Drug", axis=1).values
18
  y = drug_df.Drug.values
19
 
20
+ X_train, X_test, y_train, y_test = train_test_split(
21
+ X, y, test_size=0.3, random_state=125
22
+ )
23
 
24
  # Pipeline
25
  cat_col = [1, 2, 3]
 
34
  )
35
 
36
  pipe = Pipeline(
37
+ steps=[
38
  ("preprocessing", transform),
39
  ("model", RandomForestClassifier(n_estimators=10, random_state=125)),
40
  ]
 
48
  accuracy = accuracy_score(y_test, predictions)
49
  f1 = f1_score(y_test, predictions, average="macro")
50
 
51
+ print("Accuracy: ", str(round(accuracy, 2) * 100) + "%", "F1: ", round(f1, 2))
52
 
53
  # Confusion matrix
54
  import matplotlib.pyplot as plt
 
66
 
67
  # Save the model
68
  import pickle
69
+
70
  # save the model to disk
71
+ filename = "./model/drug_pipeline.sav"
72
+ pickle.dump(pipe, open(filename, "wb"))
73
 
74
+ # sio.dump(pipe, "./model/drug_pipeline.skops")