Spaces:
Sleeping
Sleeping
Create knn.py
Browse files
knn.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import seaborn as sns
|
| 5 |
+
from sklearn.datasets import make_classification, make_moons, make_circles, make_blobs
|
| 6 |
+
from sklearn.model_selection import train_test_split, learning_curve
|
| 7 |
+
from sklearn.neighbors import KNeighborsClassifier
|
| 8 |
+
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score
|
| 9 |
+
from mlxtend.plotting import plot_decision_regions
|
| 10 |
+
|
| 11 |
+
# image
|
| 12 |
+
st.image("https://huggingface.co/spaces/varshitha22/decision_boundary/resolve/main/logo.png")
|
| 13 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
| 14 |
+
|
| 15 |
+
def plot_learning_curves(X_train, y_train, X_test, y_test, model, scoring='accuracy'):
|
| 16 |
+
train_sizes, train_scores, test_scores = learning_curve(model, X_train, y_train, cv=5, scoring=scoring)
|
| 17 |
+
train_mean = np.mean(train_scores, axis=1)
|
| 18 |
+
test_mean = np.mean(test_scores, axis=1)
|
| 19 |
+
|
| 20 |
+
fig, ax = plt.subplots()
|
| 21 |
+
plt.plot(train_sizes, train_mean, 'o-', color="r", label="Training Score")
|
| 22 |
+
plt.plot(train_sizes, test_mean, 'o-', color="g", label="Cross-validation Score")
|
| 23 |
+
plt.xlabel("Training Examples")
|
| 24 |
+
plt.ylabel("Score")
|
| 25 |
+
plt.legend()
|
| 26 |
+
st.pyplot(fig)
|
| 27 |
+
|
| 28 |
+
# Sidebar for dataset selection
|
| 29 |
+
st.sidebar.header("Dataset Options")
|
| 30 |
+
data_type = st.sidebar.selectbox("Select Data Type:", ["Blobs", "Circles", "Moons", "Classification"])
|
| 31 |
+
noise = st.sidebar.slider("Add Noise:", 0.0, 1.0, 0.2, step=0.05)
|
| 32 |
+
|
| 33 |
+
# Sidebar for model selection
|
| 34 |
+
st.sidebar.header("Model")
|
| 35 |
+
model_name = st.sidebar.radio("Model: ","KNN")
|
| 36 |
+
|
| 37 |
+
# Display number of neighbors selector only if KNN is selected
|
| 38 |
+
if model_name == "KNN":
|
| 39 |
+
neighbors = st.sidebar.number_input("Neighbors", min_value=1, max_value=25, value=5, step=1)
|
| 40 |
+
knn_weights = st.sidebar.radio("KNN Weights:", ["uniform", "distance"])
|
| 41 |
+
|
| 42 |
+
# KNN Algorithm
|
| 43 |
+
st.sidebar.subheader("KNN Algorithm")
|
| 44 |
+
algorithms_selected = []
|
| 45 |
+
if st.sidebar.checkbox("auto", value=True):
|
| 46 |
+
algorithms_selected.append("auto")
|
| 47 |
+
if st.sidebar.checkbox("ball_tree"):
|
| 48 |
+
algorithms_selected.append("ball_tree")
|
| 49 |
+
if st.sidebar.checkbox("kd_tree"):
|
| 50 |
+
algorithms_selected.append("kd_tree")
|
| 51 |
+
if st.sidebar.checkbox("brute"):
|
| 52 |
+
algorithms_selected.append("brute")
|
| 53 |
+
|
| 54 |
+
# KNN Metric
|
| 55 |
+
st.sidebar.subheader("KNN Metric")
|
| 56 |
+
metrics_selected = []
|
| 57 |
+
if st.sidebar.checkbox("euclidean", value=True):
|
| 58 |
+
metrics_selected.append("euclidean")
|
| 59 |
+
if st.sidebar.checkbox("manhattan"):
|
| 60 |
+
metrics_selected.append("manhattan")
|
| 61 |
+
if st.sidebar.checkbox("minkowski"):
|
| 62 |
+
metrics_selected.append("minkowski")
|
| 63 |
+
|
| 64 |
+
# Generate dataset
|
| 65 |
+
if data_type == "Blobs":
|
| 66 |
+
X, y = make_blobs(n_samples=5000, centers=2, cluster_std=noise, random_state=27)
|
| 67 |
+
elif data_type == "Circles":
|
| 68 |
+
X, y = make_circles(n_samples=5000, noise=noise, factor=0.5, random_state=27)
|
| 69 |
+
elif data_type == "Moons":
|
| 70 |
+
X, y = make_moons(n_samples=5000, noise=noise, random_state=27)
|
| 71 |
+
else:
|
| 72 |
+
X, y = make_classification(n_samples=5000, n_features=2, n_classes=2, n_informative=2, n_redundant=0, random_state=27)
|
| 73 |
+
|
| 74 |
+
# Split dataset
|
| 75 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=27)
|
| 76 |
+
|
| 77 |
+
# Model selection
|
| 78 |
+
if model_name == "KNN":
|
| 79 |
+
model = KNeighborsClassifier(n_neighbors=neighbors, weights=knn_weights, algorithm=algorithms_selected[0] if algorithms_selected else 'auto', metric=metrics_selected[0] if metrics_selected else 'minkowski')
|
| 80 |
+
|
| 81 |
+
# Fit the model
|
| 82 |
+
model.fit(X_train, y_train)
|
| 83 |
+
|
| 84 |
+
# Display performance metrics only for KNN
|
| 85 |
+
if model_name == "KNN":
|
| 86 |
+
st.subheader("KNN Model Evaluation Metrics")
|
| 87 |
+
y_pred = model.predict(X_test)
|
| 88 |
+
|
| 89 |
+
# Performance metrics calculation and display
|
| 90 |
+
accuracy = accuracy_score(y_test, y_pred)
|
| 91 |
+
st.write(f"Accuracy: {accuracy:.2f}")
|
| 92 |
+
|
| 93 |
+
precision = precision_score(y_test, y_pred)
|
| 94 |
+
st.write(f"Precision: {precision:.2f}")
|
| 95 |
+
|
| 96 |
+
recall = recall_score(y_test, y_pred)
|
| 97 |
+
st.write(f"Recall: {recall:.2f}")
|
| 98 |
+
|
| 99 |
+
f1 = f1_score(y_test, y_pred)
|
| 100 |
+
st.write(f"F1 Score: {f1:.2f}")
|
| 101 |
+
|
| 102 |
+
auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]) if hasattr(model, "predict_proba") else "N/A"
|
| 103 |
+
st.write(f"AUC Score: {auc:.2f}")
|
| 104 |
+
|
| 105 |
+
# Plot dataset
|
| 106 |
+
st.subheader("Dataset Visualization")
|
| 107 |
+
fig, ax = plt.subplots()
|
| 108 |
+
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, palette="coolwarm", s=50, edgecolor="k")
|
| 109 |
+
st.pyplot(fig)
|
| 110 |
+
|
| 111 |
+
# Decision Boundary
|
| 112 |
+
st.subheader("Decision Boundary")
|
| 113 |
+
fig, ax = plt.subplots()
|
| 114 |
+
plot_decision_regions(X_train, y_train, clf=model, legend=2)
|
| 115 |
+
st.pyplot(fig)
|
| 116 |
+
|
| 117 |
+
# Learning Curve
|
| 118 |
+
st.subheader("Learning Curve")
|
| 119 |
+
plot_learning_curves(X_train, y_train, X_test, y_test, model, scoring='accuracy')
|