varshitha22 commited on
Commit
ab61b9b
·
verified ·
1 Parent(s): 6d11725

Create knn.py

Browse files
Files changed (1) hide show
  1. knn.py +119 -0
knn.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from sklearn.datasets import make_classification, make_moons, make_circles, make_blobs
6
+ from sklearn.model_selection import train_test_split, learning_curve
7
+ from sklearn.neighbors import KNeighborsClassifier
8
+ from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score
9
+ from mlxtend.plotting import plot_decision_regions
10
+
11
+ # image
12
+ st.image("https://huggingface.co/spaces/varshitha22/decision_boundary/resolve/main/logo.png")
13
+ st.markdown("<br>", unsafe_allow_html=True)
14
+
15
+ def plot_learning_curves(X_train, y_train, X_test, y_test, model, scoring='accuracy'):
16
+ train_sizes, train_scores, test_scores = learning_curve(model, X_train, y_train, cv=5, scoring=scoring)
17
+ train_mean = np.mean(train_scores, axis=1)
18
+ test_mean = np.mean(test_scores, axis=1)
19
+
20
+ fig, ax = plt.subplots()
21
+ plt.plot(train_sizes, train_mean, 'o-', color="r", label="Training Score")
22
+ plt.plot(train_sizes, test_mean, 'o-', color="g", label="Cross-validation Score")
23
+ plt.xlabel("Training Examples")
24
+ plt.ylabel("Score")
25
+ plt.legend()
26
+ st.pyplot(fig)
27
+
28
+ # Sidebar for dataset selection
29
+ st.sidebar.header("Dataset Options")
30
+ data_type = st.sidebar.selectbox("Select Data Type:", ["Blobs", "Circles", "Moons", "Classification"])
31
+ noise = st.sidebar.slider("Add Noise:", 0.0, 1.0, 0.2, step=0.05)
32
+
33
+ # Sidebar for model selection
34
+ st.sidebar.header("Model")
35
+ model_name = st.sidebar.radio("Model: ","KNN")
36
+
37
+ # Display number of neighbors selector only if KNN is selected
38
+ if model_name == "KNN":
39
+ neighbors = st.sidebar.number_input("Neighbors", min_value=1, max_value=25, value=5, step=1)
40
+ knn_weights = st.sidebar.radio("KNN Weights:", ["uniform", "distance"])
41
+
42
+ # KNN Algorithm
43
+ st.sidebar.subheader("KNN Algorithm")
44
+ algorithms_selected = []
45
+ if st.sidebar.checkbox("auto", value=True):
46
+ algorithms_selected.append("auto")
47
+ if st.sidebar.checkbox("ball_tree"):
48
+ algorithms_selected.append("ball_tree")
49
+ if st.sidebar.checkbox("kd_tree"):
50
+ algorithms_selected.append("kd_tree")
51
+ if st.sidebar.checkbox("brute"):
52
+ algorithms_selected.append("brute")
53
+
54
+ # KNN Metric
55
+ st.sidebar.subheader("KNN Metric")
56
+ metrics_selected = []
57
+ if st.sidebar.checkbox("euclidean", value=True):
58
+ metrics_selected.append("euclidean")
59
+ if st.sidebar.checkbox("manhattan"):
60
+ metrics_selected.append("manhattan")
61
+ if st.sidebar.checkbox("minkowski"):
62
+ metrics_selected.append("minkowski")
63
+
64
+ # Generate dataset
65
+ if data_type == "Blobs":
66
+ X, y = make_blobs(n_samples=5000, centers=2, cluster_std=noise, random_state=27)
67
+ elif data_type == "Circles":
68
+ X, y = make_circles(n_samples=5000, noise=noise, factor=0.5, random_state=27)
69
+ elif data_type == "Moons":
70
+ X, y = make_moons(n_samples=5000, noise=noise, random_state=27)
71
+ else:
72
+ X, y = make_classification(n_samples=5000, n_features=2, n_classes=2, n_informative=2, n_redundant=0, random_state=27)
73
+
74
+ # Split dataset
75
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=27)
76
+
77
+ # Model selection
78
+ if model_name == "KNN":
79
+ model = KNeighborsClassifier(n_neighbors=neighbors, weights=knn_weights, algorithm=algorithms_selected[0] if algorithms_selected else 'auto', metric=metrics_selected[0] if metrics_selected else 'minkowski')
80
+
81
+ # Fit the model
82
+ model.fit(X_train, y_train)
83
+
84
+ # Display performance metrics only for KNN
85
+ if model_name == "KNN":
86
+ st.subheader("KNN Model Evaluation Metrics")
87
+ y_pred = model.predict(X_test)
88
+
89
+ # Performance metrics calculation and display
90
+ accuracy = accuracy_score(y_test, y_pred)
91
+ st.write(f"Accuracy: {accuracy:.2f}")
92
+
93
+ precision = precision_score(y_test, y_pred)
94
+ st.write(f"Precision: {precision:.2f}")
95
+
96
+ recall = recall_score(y_test, y_pred)
97
+ st.write(f"Recall: {recall:.2f}")
98
+
99
+ f1 = f1_score(y_test, y_pred)
100
+ st.write(f"F1 Score: {f1:.2f}")
101
+
102
+ auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]) if hasattr(model, "predict_proba") else "N/A"
103
+ st.write(f"AUC Score: {auc:.2f}")
104
+
105
+ # Plot dataset
106
+ st.subheader("Dataset Visualization")
107
+ fig, ax = plt.subplots()
108
+ sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, palette="coolwarm", s=50, edgecolor="k")
109
+ st.pyplot(fig)
110
+
111
+ # Decision Boundary
112
+ st.subheader("Decision Boundary")
113
+ fig, ax = plt.subplots()
114
+ plot_decision_regions(X_train, y_train, clf=model, legend=2)
115
+ st.pyplot(fig)
116
+
117
+ # Learning Curve
118
+ st.subheader("Learning Curve")
119
+ plot_learning_curves(X_train, y_train, X_test, y_test, model, scoring='accuracy')