knn-vis / app.py
rinabuoy's picture
init
3d53f11
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, FancyBboxPatch, ConnectionPatch
import io
from PIL import Image
from matplotlib.patches import FancyArrowPatch
from scipy.spatial.distance import euclidean, cityblock, chebyshev
class KNNVisualizer:
    """Interactive k-nearest-neighbours demo over a fixed synthetic 2-D dataset.

    ``reset_data`` builds three reproducible Gaussian clusters (``self.X`` /
    ``self.y``); ``visualize`` renders a PNG explaining how one test point is
    classified by majority vote among its k nearest training points.
    """

    # Per-class display colour / legend name, indexed by class label 0..2.
    CLASS_COLORS = ['blue', 'red', 'green']
    CLASS_NAMES = ['Class 0 (Blue)', 'Class 1 (Red)', 'Class 2 (Green)']

    # Supported metric names -> scipy distance functions.
    _METRICS = {
        'euclidean': euclidean,
        'manhattan': cityblock,
        'chebyshev': chebyshev,
    }

    def __init__(self):
        # Assign the placeholder *before* reset_data(): the previous order
        # overwrote the default test point that reset_data() had just set.
        self.test_point = None
        self.reset_data()

    def reset_data(self):
        """Generate sample data with three classes of 30 Gaussian points each.

        A private ``RandomState(42)`` yields exactly the points that
        ``np.random.seed(42)`` would, without reseeding NumPy's global
        generator as a side effect.
        """
        rng = np.random.RandomState(42)
        n_samples = 30
        # Class 0 (blue) - bottom left
        self.X0 = rng.randn(n_samples, 2) * 1.2 + np.array([3, 3])
        # Class 1 (red) - top right
        self.X1 = rng.randn(n_samples, 2) * 1.2 + np.array([7, 7])
        # Class 2 (green) - top left
        self.X2 = rng.randn(n_samples, 2) * 1.2 + np.array([3, 7])
        self.X = np.vstack([self.X0, self.X1, self.X2])
        self.y = np.hstack([np.zeros(n_samples), np.ones(n_samples), np.full(n_samples, 2)])
        self.test_point = np.array([5.0, 5.0])

    def calculate_distance(self, point1, point2, metric='euclidean'):
        """Return the distance between two points under ``metric``.

        ``metric`` is 'euclidean', 'manhattan' (city block) or 'chebyshev';
        any other value falls back to Euclidean.
        """
        return self._METRICS.get(metric, euclidean)(point1, point2)

    def _sorted_distances(self, test_point, metric='euclidean'):
        """Return ``(index, distance, label)`` for every training point, nearest first.

        ``list.sort`` is stable, so equal distances keep training-set order.
        """
        rows = [(i, self.calculate_distance(test_point, point, metric), self.y[i])
                for i, point in enumerate(self.X)]
        rows.sort(key=lambda row: row[1])
        return rows

    def find_k_nearest_neighbors(self, test_point, k, metric='euclidean'):
        """Return the ``k`` nearest ``(index, distance, label)`` tuples, nearest first."""
        return self._sorted_distances(test_point, metric)[:k]

    def predict_class(self, neighbors):
        """Predict a class by majority vote over ``neighbors``.

        Returns ``(predicted_label, class_counts)`` where ``class_counts`` is
        ``np.bincount`` of the neighbour labels (its length is max label + 1).
        Ties go to the smallest label, because ``np.argmax`` returns the first
        maximum.

        Raises:
            ValueError: if ``neighbors`` is empty.
        """
        classes = [int(n[2]) for n in neighbors]
        if not classes:
            raise ValueError("predict_class needs at least one neighbor")
        class_counts = np.bincount(classes)
        return np.argmax(class_counts), class_counts

    def _parse_inputs(self, test_x, test_y, k_value):
        """Coerce raw UI values into ``(test_point, k)``.

        k is clamped to ``[1, len(self.X)]``. If any value cannot be converted,
        the defaults ``(5.0, 5.0)`` and ``k = 5`` are returned together.
        """
        try:
            test_point = np.array([float(test_x), float(test_y)])
            k = int(k_value)
        except (TypeError, ValueError, OverflowError):
            return np.array([5.0, 5.0]), 5
        return test_point, max(1, min(k, len(self.X)))

    def _normalize_metric(self, metric):
        """Return ``metric`` if supported, else 'euclidean' (e.g. for None)."""
        return metric if metric in self._METRICS else 'euclidean'

    @staticmethod
    def _neighborhood_vertices(center, radius, metric):
        """Vertices of the metric's ball of ``radius`` around ``center``.

        Returns None for Euclidean, whose ball is a true circle. Returns a
        diamond for Manhattan (|dx| + |dy| <= r) and an axis-aligned square
        for Chebyshev (max(|dx|, |dy|) <= r).
        """
        cx, cy = center
        r = radius
        if metric == 'manhattan':
            return np.array([[cx + r, cy], [cx, cy + r], [cx - r, cy], [cx, cy - r]])
        if metric == 'chebyshev':
            return np.array([[cx - r, cy - r], [cx + r, cy - r],
                             [cx + r, cy + r], [cx - r, cy + r]])
        return None

    @classmethod
    def _neighborhood_patch(cls, center, radius, metric, **patch_kwargs):
        """Build a matplotlib patch outlining the metric's ball of ``radius``."""
        vertices = cls._neighborhood_vertices(center, radius, metric)
        if vertices is None:
            return Circle(center, radius, **patch_kwargs)
        from matplotlib.patches import Polygon
        return Polygon(vertices, closed=True, **patch_kwargs)

    def _draw_neighbors_plot(self, ax, test_point, k, metric, neighbors,
                             predicted_class, show_all_distances):
        """Draw training points, the k neighbours and the test point on ``ax``."""
        ax.set_facecolor('#f0f0f0')

        # One scatter call per class so each class gets its own legend entry.
        for class_idx, (color, name) in enumerate(zip(self.CLASS_COLORS, self.CLASS_NAMES)):
            mask = self.y == class_idx
            ax.scatter(self.X[mask, 0], self.X[mask, 1], c=color, label=name,
                       s=100, alpha=0.6, edgecolors='k', linewidths=1.5)

        for order, (n_idx, dist, n_class) in enumerate(neighbors):
            point = self.X[n_idx]
            # Dashed ring around the neighbour, in its class colour.
            ax.add_patch(Circle(point, 0.3, color=self.CLASS_COLORS[int(n_class)],
                                fill=False, linewidth=3, linestyle='--', alpha=0.8))
            # Connect and label only the 10 closest unless all lines were requested.
            if show_all_distances or order < 10:
                ax.plot([test_point[0], point[0]], [test_point[1], point[1]],
                        'k--', alpha=0.3, linewidth=1)
                mid_x = (test_point[0] + point[0]) / 2
                mid_y = (test_point[1] + point[1]) / 2
                ax.text(mid_x, mid_y, f'{dist:.2f}', fontsize=8,
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

        # Star marker coloured by the predicted class, drawn above everything.
        ax.scatter(test_point[0], test_point[1], c=self.CLASS_COLORS[predicted_class],
                   marker='*', s=800, edgecolors='black', linewidths=3,
                   label=f'Test Point (Predicted: Class {predicted_class})',
                   zorder=100)

        # Outline of the metric's ball whose radius is the k-th neighbour distance:
        # a circle only for Euclidean; diamond/square for Manhattan/Chebyshev.
        radius = neighbors[-1][1]
        ax.add_patch(self._neighborhood_patch(
            test_point, radius, metric,
            color='purple', fill=False, linewidth=2.5, linestyle=':', alpha=0.6,
            label=f'Decision Boundary (r={radius:.2f})'))

        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.set_xlabel('Feature 1 (X)', fontsize=14, fontweight='bold')
        ax.set_ylabel('Feature 2 (Y)', fontsize=14, fontweight='bold')
        ax.set_title(f'K-Nearest Neighbors (k={k}, metric={metric})',
                     fontsize=16, fontweight='bold')
        ax.legend(fontsize=10, loc='upper left', framealpha=0.9)
        ax.set_xlim(-1, 11)
        ax.set_ylim(-1, 11)

    def _draw_distance_table(self, ax, all_distances, k, max_rows=15):
        """Render the ``max_rows`` closest points as a table, highlighting the k neighbours."""
        ax.axis('off')
        header = ['Rank', 'Index', 'X', 'Y', 'Class', 'Distance', 'Neighbor?']
        table_data = [header]
        for rank, (idx, dist, point_class) in enumerate(all_distances[:max_rows], 1):
            x, y = self.X[idx]
            table_data.append([
                f'{rank}',
                f'{idx}',
                f'{x:.2f}',
                f'{y:.2f}',
                f'{int(point_class)}',
                f'{dist:.3f}',
                '✓' if rank <= k else '',
            ])

        table = ax.table(cellText=table_data, cellLoc='center', loc='center',
                         bbox=[0, 0, 1, 1])
        table.auto_set_font_size(False)
        table.set_fontsize(9)
        table.scale(1, 2)

        n_cols = len(header)
        for col in range(n_cols):
            cell = table[(0, col)]
            cell.set_facecolor('#4CAF50')
            cell.set_text_props(weight='bold', color='white')

        for row in range(1, len(table_data)):
            # Rows are in the same sorted order as the neighbours, so the
            # first k data rows are exactly the neighbours.
            if row <= k:
                for col in range(n_cols):
                    table[(row, col)].set_facecolor('#E8F5E9')
            class_col = int(table_data[row][4])
            table[(row, 4)].set_facecolor(self.CLASS_COLORS[class_col])
            table[(row, 4)].set_alpha(0.3)

        ax.set_title('Distance Calculations (Sorted by Distance)',
                     fontsize=14, fontweight='bold', pad=20)

    @staticmethod
    def _calculation_lines(point, test_point, dist, metric):
        """Worked arithmetic for one neighbour's distance, as text lines."""
        if metric == 'euclidean':
            dx = point[0] - test_point[0]
            dy = point[1] - test_point[1]
            return [
                f" Calculation: √[({dx:.2f})² + ({dy:.2f})²]\n",
                f" = √[{dx**2:.2f} + {dy**2:.2f}]\n",
                f" = {dist:.4f}\n",
            ]
        dx = abs(point[0] - test_point[0])
        dy = abs(point[1] - test_point[1])
        if metric == 'manhattan':
            return [f" Calculation: |{dx:.2f}| + |{dy:.2f}|\n", f" = {dist:.4f}\n"]
        if metric == 'chebyshev':
            return [f" Calculation: max(|{dx:.2f}|, |{dy:.2f}|)\n", f" = {dist:.4f}\n"]
        return []

    def _build_stats_text(self, test_point, k, metric, neighbors, predicted_class, class_counts):
        """Compose the monospace explanation panel text."""
        parts = [
            "K-NEAREST NEIGHBORS ALGORITHM\n",
            "=" * 60 + "\n\n",
            "TEST POINT COORDINATES:\n",
            f" • X: {test_point[0]:.2f}\n",
            f" • Y: {test_point[1]:.2f}\n\n",
            "ALGORITHM PARAMETERS:\n",
            f" • K value: {k}\n",
            f" • Distance metric: {metric.upper()}\n",
            f" • Total training samples: {len(self.X)}\n\n",
            f"DISTANCE METRIC: {metric.upper()}\n",
            "-" * 60 + "\n",
        ]
        if metric == 'euclidean':
            parts += ["Formula: d = √[(x₂-x₁)² + (y₂-y₁)²]\n",
                      " • Standard straight-line distance\n",
                      " • Most commonly used metric\n"]
        elif metric == 'manhattan':
            parts += ["Formula: d = |x₂-x₁| + |y₂-y₁|\n",
                      " • Also called 'City Block' distance\n",
                      " • Sum of absolute differences\n"]
        elif metric == 'chebyshev':
            parts += ["Formula: d = max(|x₂-x₁|, |y₂-y₁|)\n",
                      " • Maximum absolute difference\n",
                      " • Chess king's move distance\n"]
        parts += ["\n", "K NEAREST NEIGHBORS FOUND:\n", "-" * 60 + "\n"]

        for rank, (idx, dist, point_class) in enumerate(neighbors, 1):
            point = self.X[idx]
            parts.append(f"\n{rank}. Point #{idx} (Class {int(point_class)})\n")
            parts.append(f" Position: ({point[0]:.2f}, {point[1]:.2f})\n")
            parts.append(f" Distance: {dist:.4f}\n")
            # Worked arithmetic only for the three closest, to keep the panel short.
            if rank <= 3:
                parts.extend(self._calculation_lines(point, test_point, dist, metric))

        parts += ["\n\nCLASS DISTRIBUTION IN K NEIGHBORS:\n", "-" * 60 + "\n"]
        for class_idx in range(len(self.CLASS_COLORS)):
            # bincount stops at the largest label present, so pad missing classes with 0.
            count = class_counts[class_idx] if class_idx < len(class_counts) else 0
            percentage = count / k * 100
            bar = '█' * int(percentage / 5)
            parts.append(f"Class {class_idx}: {count}/{k} ({percentage:.1f}%) {bar}\n")

        winner = class_counts[predicted_class]
        parts += [
            "\n\nPREDICTION RESULT:\n",
            "=" * 60 + "\n",
            f" → Predicted Class: {predicted_class}\n",
            f" → Confidence: {winner}/{k} neighbors\n",
            f" → Percentage: {winner / k * 100:.1f}%\n\n",
            "ALGORITHM STEPS:\n",
            "-" * 60 + "\n",
            "1. Calculate distance from test point to all\n",
            " training points using selected metric\n",
            "2. Sort all points by distance (ascending)\n",
            "3. Select the K nearest points\n",
            "4. Count class labels among K neighbors\n",
            "5. Predict class with majority vote\n",
        ]
        return ''.join(parts)

    def visualize(self, test_x, test_y, k_value, distance_metric, show_all_distances):
        """Render the full KNN explanation for one test point as a PIL image.

        Args:
            test_x, test_y: Test-point coordinates (anything ``float()`` accepts).
            k_value: Number of neighbours (anything ``int()`` accepts); clamped
                to ``[1, len(self.X)]``.
            distance_metric: 'euclidean', 'manhattan' or 'chebyshev'. Any other
                value, including None, falls back to 'euclidean'.
            show_all_distances: If true, draw a distance line for every
                neighbour instead of only the 10 closest.

        Returns:
            A PIL image of the rendered PNG. If the coordinates or k cannot be
            parsed, the defaults (5.0, 5.0) and k=5 are used instead.
        """
        test_point, k = self._parse_inputs(test_x, test_y, k_value)
        distance_metric = self._normalize_metric(distance_metric)

        # One distance pass feeds both the neighbour list and the table;
        # the neighbours are exactly the first k rows of the sorted list.
        all_distances = self._sorted_distances(test_point, distance_metric)
        neighbors = all_distances[:k]
        predicted_class, class_counts = self.predict_class(neighbors)

        fig = plt.figure(figsize=(20, 12))
        try:
            gs = fig.add_gridspec(2, 2, height_ratios=[1.2, 1], width_ratios=[1.5, 1])
            ax1 = fig.add_subplot(gs[0, 0])  # Main KNN visualization
            ax2 = fig.add_subplot(gs[1, 0])  # Distance calculations table
            ax3 = fig.add_subplot(gs[:, 1])  # Statistics and breakdown

            self._draw_neighbors_plot(ax1, test_point, k, distance_metric, neighbors,
                                      predicted_class, show_all_distances)
            self._draw_distance_table(ax2, all_distances, k)

            ax3.axis('off')
            stats_text = self._build_stats_text(test_point, k, distance_metric, neighbors,
                                                predicted_class, class_counts)
            ax3.text(0.05, 0.95, stats_text, transform=ax3.transAxes,
                     fontsize=9, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8),
                     family='monospace')

            # Act on this figure explicitly rather than pyplot's "current" one,
            # which another concurrent request could have replaced.
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
        finally:
            # pyplot keeps every open figure alive; release it even on error.
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)
# Create the visualizer
knn_viz = KNNVisualizer()

# Create Gradio interface. Every control feeds the same five inputs into
# KNNVisualizer.visualize, which returns the rendered image.
with gr.Blocks(title="K-Nearest Neighbors (KNN) Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎯 Interactive K-Nearest Neighbors (KNN) Algorithm Visualizer
Explore how KNN algorithm works by visualizing distance calculations and neighbor identification!
**Instructions:**
1. Set the test point coordinates (X, Y)
2. Choose the number of neighbors (K)
3. Select a distance metric
4. Click "Update Visualization" to see the results
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Test Point Configuration")
            test_x = gr.Slider(minimum=-1, maximum=11, value=5.0, step=0.1,
                               label="Test Point X Coordinate")
            test_y = gr.Slider(minimum=-1, maximum=11, value=5.0, step=0.1,
                               label="Test Point Y Coordinate")
            gr.Markdown("### KNN Parameters")
            k_value = gr.Slider(minimum=1, maximum=20, value=5, step=1,
                                label="K (Number of Neighbors)")
            distance_metric = gr.Radio(
                choices=['euclidean', 'manhattan', 'chebyshev'],
                value='euclidean',
                label="Distance Metric"
            )
            show_all_distances = gr.Checkbox(
                value=False,
                label="Show all distance lines (may be cluttered)"
            )
            update_btn = gr.Button("🔄 Update Visualization", variant="primary", size="lg")
            gr.Markdown("""
### Distance Metrics:
- **Euclidean**: Standard straight-line distance
- **Manhattan**: Sum of absolute differences (city block)
- **Chebyshev**: Maximum absolute difference
### Try These Examples:
- **Test Point (5, 5), K=5**: See balanced classification
- **Test Point (2, 2), K=3**: Point near Class 0
- **Test Point (8, 8), K=7**: Point near Class 1
- **Different K values**: See how it affects prediction
""")
        with gr.Column(scale=2):
            output_image = gr.Image(label="KNN Visualization", height=900)

    # Same input order as KNNVisualizer.visualize's parameters.
    viz_inputs = [test_x, test_y, k_value, distance_metric, show_all_distances]

    # Update visualization on button click
    update_btn.click(fn=knn_viz.visualize, inputs=viz_inputs, outputs=output_image)
    # Also update on slider/radio/checkbox change
    for component in viz_inputs:
        component.change(fn=knn_viz.visualize, inputs=viz_inputs, outputs=output_image)
    # Initial visualization when the page loads
    demo.load(fn=knn_viz.visualize, inputs=viz_inputs, outputs=output_image)

# Launch the app
if __name__ == "__main__":
    demo.launch()