Hem345 commited on
Commit
3891746
·
verified ·
1 Parent(s): 4fdac11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import numpy as np
3
  from sklearn.neighbors import KNeighborsClassifier
 
4
 
5
  def get_user_data_train():
6
  data_points = []
@@ -33,6 +34,23 @@ def knn_classification(X_train, y_train, X_test, k_value):
33
  predictions = knn_classifier.predict(X_test)
34
  return predictions
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def main():
37
  st.title("K-Nearest Neighbor Classification App")
38
 
@@ -48,6 +66,9 @@ def main():
48
  # Perform k-nearest neighbor classification
49
  predictions = knn_classification(X_train, y_train, X_test, k_value)
50
 
 
 
 
51
  # Display results
52
  st.subheader("Results:")
53
  st.write("User-defined Data Points for Testing:")
 
1
  import streamlit as st
2
  import numpy as np
3
  from sklearn.neighbors import KNeighborsClassifier
4
+ import matplotlib.pyplot as plt
5
 
6
  def get_user_data_train():
7
  data_points = []
 
34
  predictions = knn_classifier.predict(X_test)
35
  return predictions
36
 
37
+ def plot_training_and_test_data(X_train, y_train, X_test, predictions):
38
+ unique_labels = np.unique(y_train)
39
+
40
+ # Plot training data
41
+ for label in unique_labels:
42
+ indices = np.where(y_train == label)
43
+ plt.scatter(X_train[indices, 0], X_train[indices, 1], label=f'Training ({label})')
44
+
45
+ # Plot test data with predicted labels
46
+ plt.scatter(X_test[:, 0], X_test[:, 1], label=f'Test (Predicted Labels)', marker='x', c=predictions)
47
+
48
+ plt.xlabel('X-coordinate')
49
+ plt.ylabel('Y-coordinate')
50
+ plt.title('Training and Test Data with Predicted Labels')
51
+ plt.legend()
52
+ st.pyplot()
53
+
54
  def main():
55
  st.title("K-Nearest Neighbor Classification App")
56
 
 
66
  # Perform k-nearest neighbor classification
67
  predictions = knn_classification(X_train, y_train, X_test, k_value)
68
 
69
+ # Plot training and test data
70
+ plot_training_and_test_data(X_train, y_train, X_test, predictions)
71
+
72
  # Display results
73
  st.subheader("Results:")
74
  st.write("User-defined Data Points for Testing:")