Spaces:
Sleeping
Sleeping
}
commited on
Commit
·
08461c2
1
Parent(s):
fa9fe39
Update with new results
Browse files- app.py +12 -5
- model/drug_pipeline.sav +0 -0
- results/metrics.txt +1 -1
- results/model_result.png +0 -0
- train.py +9 -6
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from utils.prediction import predict_drug
|
|
| 6 |
|
| 7 |
# Loading the model
|
| 8 |
filename = "./model/drug_pipeline.sav"
|
| 9 |
-
pipe = pickle.load(open(filename,
|
| 10 |
|
| 11 |
|
| 12 |
# Web interface section
|
|
@@ -27,15 +27,22 @@ with st.sidebar.expander("Single Prediction"):
|
|
| 27 |
cholesterol_list = np.array(["HIGH", "NORMAL"])
|
| 28 |
cholesterol = st.radio("Select your Cholesterol", cholesterol_list)
|
| 29 |
|
| 30 |
-
Na_to_k = st.number_input(
|
|
|
|
|
|
|
| 31 |
submit_button = st.form_submit_button(label="Predict")
|
| 32 |
|
| 33 |
if submit_button:
|
| 34 |
-
prediction = predict_drug(
|
|
|
|
|
|
|
| 35 |
|
| 36 |
if prediction is None:
|
| 37 |
st.error("An error occurred while getting the prediction!")
|
| 38 |
|
| 39 |
message = f"The Drug is {prediction}!"
|
| 40 |
-
message_color =
|
| 41 |
-
st.markdown(
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# Loading the model
|
| 8 |
filename = "./model/drug_pipeline.sav"
|
| 9 |
+
pipe = pickle.load(open(filename, "rb"))
|
| 10 |
|
| 11 |
|
| 12 |
# Web interface section
|
|
|
|
| 27 |
cholesterol_list = np.array(["HIGH", "NORMAL"])
|
| 28 |
cholesterol = st.radio("Select your Cholesterol", cholesterol_list)
|
| 29 |
|
| 30 |
+
Na_to_k = st.number_input(
|
| 31 |
+
label="NA_to_K", min_value=6.2, step=0.1, max_value=38.2
|
| 32 |
+
)
|
| 33 |
submit_button = st.form_submit_button(label="Predict")
|
| 34 |
|
| 35 |
if submit_button:
|
| 36 |
+
prediction = predict_drug(
|
| 37 |
+
age, gender[0], blood_pressure, cholesterol, Na_to_k, pipe
|
| 38 |
+
)
|
| 39 |
|
| 40 |
if prediction is None:
|
| 41 |
st.error("An error occurred while getting the prediction!")
|
| 42 |
|
| 43 |
message = f"The Drug is {prediction}!"
|
| 44 |
+
message_color = "red" if prediction == 1 else "green"
|
| 45 |
+
st.markdown(
|
| 46 |
+
f"<h3 style='text-align: left;color:{message_color}'> {(message)} </h3>",
|
| 47 |
+
unsafe_allow_html=True,
|
| 48 |
+
)
|
model/drug_pipeline.sav
CHANGED
|
Binary files a/model/drug_pipeline.sav and b/model/drug_pipeline.sav differ
|
|
|
results/metrics.txt
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
|
| 2 |
-
Accuracy=0.
|
|
|
|
| 1 |
|
| 2 |
+
Accuracy=0.97, F1_score = 0.85
|
results/model_result.png
CHANGED
|
|
train.py
CHANGED
|
@@ -17,7 +17,9 @@ from sklearn.model_selection import train_test_split
|
|
| 17 |
X = drug_df.drop("Drug", axis=1).values
|
| 18 |
y = drug_df.Drug.values
|
| 19 |
|
| 20 |
-
X_train, X_test, y_train, y_test = train_test_split(
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# Pipeline
|
| 23 |
cat_col = [1, 2, 3]
|
|
@@ -32,7 +34,7 @@ transform = ColumnTransformer(
|
|
| 32 |
)
|
| 33 |
|
| 34 |
pipe = Pipeline(
|
| 35 |
-
steps
|
| 36 |
("preprocessing", transform),
|
| 37 |
("model", RandomForestClassifier(n_estimators=10, random_state=125)),
|
| 38 |
]
|
|
@@ -46,7 +48,7 @@ predictions = pipe.predict(X_test)
|
|
| 46 |
accuracy = accuracy_score(y_test, predictions)
|
| 47 |
f1 = f1_score(y_test, predictions, average="macro")
|
| 48 |
|
| 49 |
-
print("Accuracy: ", str(round(accuracy, 2)*100)+"%","F1: ", round(f1, 2))
|
| 50 |
|
| 51 |
# Confusion matrix
|
| 52 |
import matplotlib.pyplot as plt
|
|
@@ -64,8 +66,9 @@ with open("./results/metrics.txt", "w") as outfile:
|
|
| 64 |
|
| 65 |
# Save the model
|
| 66 |
import pickle
|
|
|
|
| 67 |
# save the model to disk
|
| 68 |
-
filename =
|
| 69 |
-
pickle.dump(pipe, open(filename,
|
| 70 |
|
| 71 |
-
#sio.dump(pipe, "./model/drug_pipeline.skops")
|
|
|
|
| 17 |
X = drug_df.drop("Drug", axis=1).values
|
| 18 |
y = drug_df.Drug.values
|
| 19 |
|
| 20 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 21 |
+
X, y, test_size=0.3, random_state=125
|
| 22 |
+
)
|
| 23 |
|
| 24 |
# Pipeline
|
| 25 |
cat_col = [1, 2, 3]
|
|
|
|
| 34 |
)
|
| 35 |
|
| 36 |
pipe = Pipeline(
|
| 37 |
+
steps=[
|
| 38 |
("preprocessing", transform),
|
| 39 |
("model", RandomForestClassifier(n_estimators=10, random_state=125)),
|
| 40 |
]
|
|
|
|
| 48 |
accuracy = accuracy_score(y_test, predictions)
|
| 49 |
f1 = f1_score(y_test, predictions, average="macro")
|
| 50 |
|
| 51 |
+
print("Accuracy: ", str(round(accuracy, 2) * 100) + "%", "F1: ", round(f1, 2))
|
| 52 |
|
| 53 |
# Confusion matrix
|
| 54 |
import matplotlib.pyplot as plt
|
|
|
|
| 66 |
|
| 67 |
# Save the model
|
| 68 |
import pickle
|
| 69 |
+
|
| 70 |
# save the model to disk
|
| 71 |
+
filename = "./model/drug_pipeline.sav"
|
| 72 |
+
pickle.dump(pipe, open(filename, "wb"))
|
| 73 |
|
| 74 |
+
# sio.dump(pipe, "./model/drug_pipeline.skops")
|