selva1909 commited on
Commit
33dcb6b
·
verified ·
1 Parent(s): d7696f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from sklearn import datasets
4
+ from sklearn.neighbors import KNeighborsClassifier
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import accuracy_score
7
+ import gradio as gr
8
+ import tempfile
9
+ import os
10
+
11
+ # ----------------- GLOBAL VARIABLES -------------------
12
+ X, y = None, None
13
+ X_train, X_test, y_train, y_test = None, None, None, None
14
+
15
+
16
+ def split_dataset(test_ratio):
17
+ global X, y, X_train, X_test, y_train, y_test
18
+
19
+ X, y = datasets.make_blobs(
20
+ n_samples=300,
21
+ centers=3,
22
+ cluster_std=2.0,
23
+ random_state=None
24
+ )
25
+
26
+ X_train, X_test, y_train, y_test = train_test_split(
27
+ X, y, test_size=test_ratio, random_state=None
28
+ )
29
+
30
+ return f"Dataset split successfully!\nTrain size: {len(X_train)}\nTest size: {len(X_test)}"
31
+
32
+
33
+ def visualize_knn(n_neighbors):
34
+ global X_train, X_test, y_train, y_test
35
+
36
+ if X_train is None:
37
+ return None, "⚠ Please click 'Split Dataset' first!"
38
+
39
+ n_neighbors = int(n_neighbors)
40
+ model = KNeighborsClassifier(n_neighbors=n_neighbors)
41
+ model.fit(X_train, y_train)
42
+
43
+ y_pred = model.predict(X_test)
44
+ acc = accuracy_score(y_test, y_pred)
45
+
46
+ x_min, x_max = min(X_train[:, 0].min(), X_test[:, 0].min()) - 1, max(X_train[:, 0].max(), X_test[:, 0].max()) + 1
47
+ y_min, y_max = min(X_train[:, 1].min(), X_test[:, 1].min()) - 1, max(X_train[:, 1].max(), X_test[:, 1].max()) + 1
48
+
49
+ xx, yy = np.meshgrid(
50
+ np.linspace(x_min, x_max, 300),
51
+ np.linspace(y_min, y_max, 300)
52
+ )
53
+
54
+ Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
55
+ Z = Z.reshape(xx.shape)
56
+
57
+ plt.figure(figsize=(7, 7))
58
+ plt.contourf(xx, yy, Z, alpha=0.4, cmap="Accent")
59
+ plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap="Accent", edgecolors="black", marker="o")
60
+ plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap="Accent", edgecolors="black", marker="^")
61
+ plt.title(f"KNN Decision Boundary (k = {n_neighbors})")
62
+
63
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
64
+ plt.savefig(temp_file.name)
65
+ plt.close()
66
+
67
+ return temp_file.name, f"Accuracy: {acc:.4f}"
68
+
69
+
70
+ custom_css = """
71
+ .gr-button {
72
+ background-color: #007bff !important;
73
+ color: white !important;
74
+ border-radius: 8px !important;
75
+ padding: 12px 20px !important;
76
+ font-weight: bold !important;
77
+ }
78
+ .gr-slider input {
79
+ accent-color: #007bff !important;
80
+ }
81
+ body, .gradio-container {
82
+ background: #1f1f1f !important;
83
+ color: white !important;
84
+ }
85
+ .gr-box, .gr-textbox, .gr-markdown {
86
+ color: white !important;
87
+ }
88
+ """
89
+
90
+ with gr.Blocks(css=custom_css) as demo:
91
+
92
+ gr.Markdown("## 🧠 KNN Decision Boundary + Dynamic Train/Test Split Visualizer")
93
+
94
+ with gr.Row():
95
+ with gr.Column(scale=1):
96
+ split_ratio = gr.Slider(0.1, 0.5, value=0.3, step=0.05, label="Test Size Ratio")
97
+ split_btn = gr.Button("Split Dataset")
98
+ split_output = gr.Textbox(label="Split Result", interactive=False)
99
+
100
+ k_slider = gr.Slider(1, 20, value=3, step=1, label="K Value (n_neighbors)")
101
+ visualize_btn = gr.Button("Visualize")
102
+
103
+ with gr.Column(scale=2):
104
+ output_img = gr.Image()
105
+ accuracy_text = gr.Textbox(label="Model Accuracy", interactive=False)
106
+
107
+ split_btn.click(split_dataset, inputs=[split_ratio], outputs=[split_output])
108
+ visualize_btn.click(visualize_knn, inputs=[k_slider], outputs=[output_img, accuracy_text])
109
+
110
+
111
+ demo.launch(server_name="0.0.0.0", server_port=7860)