| import gradio as gr |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from matplotlib.patches import Circle, FancyBboxPatch, ConnectionPatch |
| import io |
| from PIL import Image |
| from matplotlib.patches import FancyArrowPatch |
| from scipy.spatial.distance import euclidean, cityblock, chebyshev |
|
|
| class KNNVisualizer: |
| def __init__(self): |
| self.reset_data() |
| self.test_point = None |
| |
| def reset_data(self): |
| """Generate sample data with three classes""" |
| np.random.seed(42) |
| |
| n_samples = 30 |
| self.X0 = np.random.randn(n_samples, 2) * 1.2 + np.array([3, 3]) |
| |
| self.X1 = np.random.randn(n_samples, 2) * 1.2 + np.array([7, 7]) |
| |
| self.X2 = np.random.randn(n_samples, 2) * 1.2 + np.array([3, 7]) |
| |
| self.X = np.vstack([self.X0, self.X1, self.X2]) |
| self.y = np.hstack([np.zeros(n_samples), np.ones(n_samples), np.full(n_samples, 2)]) |
| self.test_point = np.array([5.0, 5.0]) |
| |
| def calculate_distance(self, point1, point2, metric='euclidean'): |
| """Calculate distance between two points using specified metric""" |
| if metric == 'euclidean': |
| return euclidean(point1, point2) |
| elif metric == 'manhattan': |
| return cityblock(point1, point2) |
| elif metric == 'chebyshev': |
| return chebyshev(point1, point2) |
| else: |
| return euclidean(point1, point2) |
| |
| def find_k_nearest_neighbors(self, test_point, k, metric='euclidean'): |
| """Find k nearest neighbors to the test point""" |
| distances = [] |
| for i, point in enumerate(self.X): |
| dist = self.calculate_distance(test_point, point, metric) |
| distances.append((i, dist, self.y[i])) |
| |
| |
| distances.sort(key=lambda x: x[1]) |
| |
| |
| return distances[:k] |
| |
| def predict_class(self, neighbors): |
| """Predict class based on majority vote from neighbors""" |
| classes = [int(n[2]) for n in neighbors] |
| class_counts = np.bincount(classes) |
| return np.argmax(class_counts), class_counts |
| |
| def visualize(self, test_x, test_y, k_value, distance_metric, show_all_distances): |
| """Create comprehensive KNN visualization""" |
| fig = plt.figure(figsize=(20, 12)) |
| gs = fig.add_gridspec(2, 2, height_ratios=[1.2, 1], width_ratios=[1.5, 1]) |
| |
| ax1 = fig.add_subplot(gs[0, 0]) |
| ax2 = fig.add_subplot(gs[1, 0]) |
| ax3 = fig.add_subplot(gs[:, 1]) |
| |
| |
| try: |
| test_point = np.array([float(test_x), float(test_y)]) |
| k = int(k_value) |
| k = max(1, min(k, len(self.X))) |
| except: |
| test_point = np.array([5.0, 5.0]) |
| k = 5 |
| |
| |
| neighbors = self.find_k_nearest_neighbors(test_point, k, distance_metric) |
| predicted_class, class_counts = self.predict_class(neighbors) |
| |
| |
| ax1.set_facecolor('#f0f0f0') |
| |
| |
| class_colors = ['blue', 'red', 'green'] |
| class_names = ['Class 0 (Blue)', 'Class 1 (Red)', 'Class 2 (Green)'] |
| |
| |
| for class_idx in range(3): |
| mask = self.y == class_idx |
| ax1.scatter(self.X[mask, 0], self.X[mask, 1], |
| c=class_colors[class_idx], label=class_names[class_idx], |
| s=100, alpha=0.6, edgecolors='k', linewidths=1.5) |
| |
| |
| neighbor_indices = [n[0] for n in neighbors] |
| neighbor_distances = [n[1] for n in neighbors] |
| |
| for idx, (n_idx, dist, n_class) in enumerate(neighbors): |
| point = self.X[n_idx] |
| |
| circle = Circle(point, 0.3, color=class_colors[int(n_class)], |
| fill=False, linewidth=3, linestyle='--', alpha=0.8) |
| ax1.add_patch(circle) |
| |
| |
| if show_all_distances or idx < 10: |
| ax1.plot([test_point[0], point[0]], [test_point[1], point[1]], |
| 'k--', alpha=0.3, linewidth=1) |
| |
| |
| mid_x = (test_point[0] + point[0]) / 2 |
| mid_y = (test_point[1] + point[1]) / 2 |
| ax1.text(mid_x, mid_y, f'{dist:.2f}', |
| fontsize=8, bbox=dict(boxstyle='round,pad=0.3', |
| facecolor='yellow', alpha=0.7)) |
| |
| |
| ax1.scatter(test_point[0], test_point[1], c=class_colors[predicted_class], |
| marker='*', s=800, edgecolors='black', linewidths=3, |
| label=f'Test Point (Predicted: Class {predicted_class})', |
| zorder=100) |
| |
| |
| max_neighbor_dist = neighbors[-1][1] |
| boundary_circle = Circle(test_point, max_neighbor_dist, |
| color='purple', fill=False, linewidth=2.5, |
| linestyle=':', alpha=0.6, |
| label=f'Decision Boundary (r={max_neighbor_dist:.2f})') |
| ax1.add_patch(boundary_circle) |
| |
| |
| ax1.grid(True, alpha=0.3, linestyle='--', linewidth=0.5) |
| ax1.set_xlabel('Feature 1 (X)', fontsize=14, fontweight='bold') |
| ax1.set_ylabel('Feature 2 (Y)', fontsize=14, fontweight='bold') |
| ax1.set_title(f'K-Nearest Neighbors (k={k}, metric={distance_metric})', |
| fontsize=16, fontweight='bold') |
| ax1.legend(fontsize=10, loc='upper left', framealpha=0.9) |
| ax1.set_xlim(-1, 11) |
| ax1.set_ylim(-1, 11) |
| |
| |
| ax2.axis('off') |
| |
| |
| table_data = [] |
| table_data.append(['Rank', 'Index', 'X', 'Y', 'Class', 'Distance', 'Neighbor?']) |
| |
| |
| all_distances = [] |
| for i, point in enumerate(self.X): |
| dist = self.calculate_distance(test_point, point, distance_metric) |
| all_distances.append((i, dist, self.y[i])) |
| all_distances.sort(key=lambda x: x[1]) |
| |
| |
| display_count = min(15, len(all_distances)) |
| for rank, (idx, dist, point_class) in enumerate(all_distances[:display_count], 1): |
| point = self.X[idx] |
| is_neighbor = 'β' if rank <= k else '' |
| |
| row = [ |
| f'{rank}', |
| f'{idx}', |
| f'{point[0]:.2f}', |
| f'{point[1]:.2f}', |
| f'{int(point_class)}', |
| f'{dist:.3f}', |
| is_neighbor |
| ] |
| table_data.append(row) |
| |
| |
| table = ax2.table(cellText=table_data, cellLoc='center', loc='center', |
| bbox=[0, 0, 1, 1]) |
| table.auto_set_font_size(False) |
| table.set_fontsize(9) |
| table.scale(1, 2) |
| |
| |
| for i in range(7): |
| cell = table[(0, i)] |
| cell.set_facecolor('#4CAF50') |
| cell.set_text_props(weight='bold', color='white') |
| |
| |
| for i in range(1, len(table_data)): |
| |
| if i <= k: |
| for j in range(7): |
| table[(i, j)].set_facecolor('#E8F5E9') |
| |
| |
| class_col = int(table_data[i][4]) |
| table[(i, 4)].set_facecolor(class_colors[class_col]) |
| table[(i, 4)].set_alpha(0.3) |
| |
| ax2.set_title('Distance Calculations (Sorted by Distance)', |
| fontsize=14, fontweight='bold', pad=20) |
| |
| |
| ax3.axis('off') |
| |
| stats_text = "K-NEAREST NEIGHBORS ALGORITHM\n" |
| stats_text += "="*60 + "\n\n" |
| |
| stats_text += f"TEST POINT COORDINATES:\n" |
| stats_text += f" β’ X: {test_point[0]:.2f}\n" |
| stats_text += f" β’ Y: {test_point[1]:.2f}\n\n" |
| |
| stats_text += f"ALGORITHM PARAMETERS:\n" |
| stats_text += f" β’ K value: {k}\n" |
| stats_text += f" β’ Distance metric: {distance_metric.upper()}\n" |
| stats_text += f" β’ Total training samples: {len(self.X)}\n\n" |
| |
| |
| stats_text += f"DISTANCE METRIC: {distance_metric.upper()}\n" |
| stats_text += "-"*60 + "\n" |
| if distance_metric == 'euclidean': |
| stats_text += "Formula: d = β[(xβ-xβ)Β² + (yβ-yβ)Β²]\n" |
| stats_text += " β’ Standard straight-line distance\n" |
| stats_text += " β’ Most commonly used metric\n" |
| elif distance_metric == 'manhattan': |
| stats_text += "Formula: d = |xβ-xβ| + |yβ-yβ|\n" |
| stats_text += " β’ Also called 'City Block' distance\n" |
| stats_text += " β’ Sum of absolute differences\n" |
| elif distance_metric == 'chebyshev': |
| stats_text += "Formula: d = max(|xβ-xβ|, |yβ-yβ|)\n" |
| stats_text += " β’ Maximum absolute difference\n" |
| stats_text += " β’ Chess king's move distance\n" |
| stats_text += "\n" |
| |
| stats_text += f"K NEAREST NEIGHBORS FOUND:\n" |
| stats_text += "-"*60 + "\n" |
| for rank, (idx, dist, point_class) in enumerate(neighbors, 1): |
| point = self.X[idx] |
| stats_text += f"\n{rank}. Point #{idx} (Class {int(point_class)})\n" |
| stats_text += f" Position: ({point[0]:.2f}, {point[1]:.2f})\n" |
| stats_text += f" Distance: {dist:.4f}\n" |
| |
| |
| if rank <= 3: |
| if distance_metric == 'euclidean': |
| dx = point[0] - test_point[0] |
| dy = point[1] - test_point[1] |
| stats_text += f" Calculation: β[({dx:.2f})Β² + ({dy:.2f})Β²]\n" |
| stats_text += f" = β[{dx**2:.2f} + {dy**2:.2f}]\n" |
| stats_text += f" = {dist:.4f}\n" |
| elif distance_metric == 'manhattan': |
| dx = abs(point[0] - test_point[0]) |
| dy = abs(point[1] - test_point[1]) |
| stats_text += f" Calculation: |{dx:.2f}| + |{dy:.2f}|\n" |
| stats_text += f" = {dist:.4f}\n" |
| |
| stats_text += "\n\nCLASS DISTRIBUTION IN K NEIGHBORS:\n" |
| stats_text += "-"*60 + "\n" |
| for class_idx in range(3): |
| count = class_counts[class_idx] if class_idx < len(class_counts) else 0 |
| percentage = (count / k) * 100 |
| bar = 'β' * int(percentage / 5) |
| stats_text += f"Class {class_idx}: {count}/{k} ({percentage:.1f}%) {bar}\n" |
| |
| stats_text += f"\n\nPREDICTION RESULT:\n" |
| stats_text += "="*60 + "\n" |
| stats_text += f" β Predicted Class: {predicted_class}\n" |
| stats_text += f" β Confidence: {class_counts[predicted_class]}/{k} neighbors\n" |
| stats_text += f" β Percentage: {(class_counts[predicted_class]/k)*100:.1f}%\n\n" |
| |
| stats_text += "ALGORITHM STEPS:\n" |
| stats_text += "-"*60 + "\n" |
| stats_text += "1. Calculate distance from test point to all\n" |
| stats_text += " training points using selected metric\n" |
| stats_text += "2. Sort all points by distance (ascending)\n" |
| stats_text += "3. Select the K nearest points\n" |
| stats_text += "4. Count class labels among K neighbors\n" |
| stats_text += "5. Predict class with majority vote\n" |
| |
| ax3.text(0.05, 0.95, stats_text, transform=ax3.transAxes, |
| fontsize=9, verticalalignment='top', |
| bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8), |
| family='monospace') |
| |
| plt.tight_layout() |
| |
| |
| buf = io.BytesIO() |
| plt.savefig(buf, format='png', dpi=120, bbox_inches='tight') |
| buf.seek(0) |
| img = Image.open(buf) |
| plt.close() |
| |
| return img |
|
|
| |
# Single shared visualizer instance used by every Gradio callback.
knn_viz = KNNVisualizer()


# --- Gradio UI -----------------------------------------------------------
with gr.Blocks(title="K-Nearest Neighbors (KNN) Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎯 Interactive K-Nearest Neighbors (KNN) Algorithm Visualizer

    Explore how KNN algorithm works by visualizing distance calculations and neighbor identification!

    **Instructions:**
    1. Set the test point coordinates (X, Y)
    2. Choose the number of neighbors (K)
    3. Select a distance metric
    4. Click "Update Visualization" to see the results
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Test Point Configuration")
            test_x = gr.Slider(minimum=-1, maximum=11, value=5.0, step=0.1,
                               label="Test Point X Coordinate")
            test_y = gr.Slider(minimum=-1, maximum=11, value=5.0, step=0.1,
                               label="Test Point Y Coordinate")

            gr.Markdown("### KNN Parameters")
            k_value = gr.Slider(minimum=1, maximum=20, value=5, step=1,
                                label="K (Number of Neighbors)")
            distance_metric = gr.Radio(
                choices=['euclidean', 'manhattan', 'chebyshev'],
                value='euclidean',
                label="Distance Metric"
            )
            show_all_distances = gr.Checkbox(
                value=False,
                label="Show all distance lines (may be cluttered)"
            )

            update_btn = gr.Button("🔄 Update Visualization", variant="primary", size="lg")

            gr.Markdown("""
            ### Distance Metrics:
            - **Euclidean**: Standard straight-line distance
            - **Manhattan**: Sum of absolute differences (city block)
            - **Chebyshev**: Maximum absolute difference

            ### Try These Examples:
            - **Test Point (5, 5), K=5**: See balanced classification
            - **Test Point (2, 2), K=3**: Point near Class 0
            - **Test Point (8, 8), K=7**: Point near Class 1
            - **Different K values**: See how it affects prediction
            """)

        with gr.Column(scale=2):
            output_image = gr.Image(label="KNN Visualization", height=900)

    # Explicit refresh via the button.
    update_btn.click(
        fn=knn_viz.visualize,
        inputs=[test_x, test_y, k_value, distance_metric, show_all_distances],
        outputs=output_image
    )

    # Live refresh: re-render whenever any input control changes.
    for component in [test_x, test_y, k_value, distance_metric, show_all_distances]:
        component.change(
            fn=knn_viz.visualize,
            inputs=[test_x, test_y, k_value, distance_metric, show_all_distances],
            outputs=output_image
        )

    # Initial render when the page loads.
    demo.load(
        fn=knn_viz.visualize,
        inputs=[test_x, test_y, k_value, distance_metric, show_all_distances],
        outputs=output_image
    )


if __name__ == "__main__":
    demo.launch()
|
|
|
|