# decision-tree / app.py
# Interactive multi-split decision-tree visualizer (Hugging Face Space
# "decision-tree" by rinabuoy, commit 2283bc3).
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyBboxPatch
import io
from PIL import Image
from matplotlib.patches import FancyArrowPatch
class TreeNode:
    """A single node of the manually-built decision tree.

    A node is created as a leaf; the tree builder flips ``is_leaf`` to
    False and fills ``feature``/``threshold``/``left``/``right`` when a
    split is applied to it.
    """

    def __init__(self, depth=0, bounds=None):
        self.depth = depth
        # Axis-aligned rectangle of feature space this node covers.
        # Fix: test against None (not truthiness) so an explicitly passed
        # empty/partial dict is not silently replaced by the default.
        self.bounds = bounds if bounds is not None else {'x': (0, 10), 'y': (0, 10)}
        self.feature = None        # split axis: 'x' or 'y' (internal nodes only)
        self.threshold = None      # split value on that axis (internal nodes only)
        self.left = None           # child where feature <= threshold
        self.right = None          # child where feature > threshold
        self.is_leaf = True
        self.samples = None        # number of training points reaching this node
        self.class_counts = None   # {class_label: count} for those points
        self.entropy = None        # Shannon entropy (base 2) of the labels here
        self.gini = None           # Gini impurity of the labels here
        self.majority_class = None # predicted class: most frequent label here
class DecisionTreePartitioner:
    """Builds and renders a decision tree from a user-supplied split list.

    Holds a fixed two-class 2-D toy dataset, an ordered list of axis-aligned
    splits, and the ``TreeNode`` tree built from them.  ``visualize`` parses
    the split text, rebuilds the tree, and renders three panels (partitioned
    feature space, tree diagram, impurity statistics) into one PIL image.
    """

    def __init__(self):
        self.reset_data()
        self.splits = []  # ordered (feature, threshold) pairs, feature in {'x', 'y'}
        self.root = None  # root TreeNode, or None while no splits are applied

    def reset_data(self):
        """Generate the reproducible two-class sample and clear any tree."""
        np.random.seed(42)  # fixed seed: identical data (and plots) on every run
        n_samples = 50
        # Class 0 (blue) clusters around (3, 3); class 1 (red) around (7, 7).
        self.X0 = np.random.randn(n_samples, 2) * 1.5 + np.array([3, 3])
        self.X1 = np.random.randn(n_samples, 2) * 1.5 + np.array([7, 7])
        self.X = np.vstack([self.X0, self.X1])
        self.y = np.hstack([np.zeros(n_samples), np.ones(n_samples)])
        self.splits = []
        self.root = None

    def calculate_entropy(self, y):
        """Return the base-2 Shannon entropy of label array *y* (0 if empty)."""
        if len(y) == 0:
            return 0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        # np.unique only reports classes that actually occur, so every
        # probability is strictly positive and log2 is well-defined
        # (the previous +1e-10 epsilon slightly biased the result).
        return -np.sum(probabilities * np.log2(probabilities))

    def calculate_gini(self, y):
        """Return the Gini impurity of label array *y* (0 if empty)."""
        if len(y) == 0:
            return 0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return 1 - np.sum(probabilities ** 2)

    def build_tree_from_splits(self):
        """(Re)build the tree by applying ``self.splits`` in order.

        Returns the new root, or None when the split list is empty.  Fix:
        an empty split list now also clears ``self.root`` so a stale tree
        from a previous build can never be drawn.
        """
        if not self.splits:
            self.root = None
            return None
        self.root = TreeNode(depth=0)
        self._build_node(self.root, np.arange(len(self.y)), 0)
        return self.root

    def _build_node(self, node, indices, split_idx):
        """Fill *node* with stats for sample *indices*, then apply split *split_idx*.

        Note: split k is applied at depth k along *every* branch, so a single
        entry in ``self.splits`` can create several internal nodes.
        """
        if len(indices) == 0:
            return
        # Per-node statistics shown in the tree diagram.
        node.samples = len(indices)
        y_node = self.y[indices]
        unique, counts = np.unique(y_node, return_counts=True)
        node.class_counts = dict(zip(unique.astype(int), counts))
        node.entropy = self.calculate_entropy(y_node)
        node.gini = self.calculate_gini(y_node)
        node.majority_class = int(unique[np.argmax(counts)])
        # No more splits to apply -> this node stays a leaf.
        if split_idx >= len(self.splits):
            node.is_leaf = True
            return
        feature, threshold = self.splits[split_idx]
        feature_idx = 0 if feature == 'x' else 1
        X_node = self.X[indices]
        left_mask = X_node[:, feature_idx] <= threshold
        left_indices = indices[left_mask]
        right_indices = indices[~left_mask]
        # Only materialize the split when both children receive samples;
        # otherwise the split is a no-op here and the node remains a leaf.
        if len(left_indices) > 0 and len(right_indices) > 0:
            node.is_leaf = False
            node.feature = feature
            node.threshold = threshold
            # Children inherit the parent's rectangle, narrowed on the split axis.
            left_bounds = node.bounds.copy()
            right_bounds = node.bounds.copy()
            if feature == 'x':
                left_bounds['x'] = (node.bounds['x'][0], threshold)
                right_bounds['x'] = (threshold, node.bounds['x'][1])
            else:
                left_bounds['y'] = (node.bounds['y'][0], threshold)
                right_bounds['y'] = (threshold, node.bounds['y'][1])
            node.left = TreeNode(depth=node.depth + 1, bounds=left_bounds)
            node.right = TreeNode(depth=node.depth + 1, bounds=right_bounds)
            self._build_node(node.left, left_indices, split_idx + 1)
            self._build_node(node.right, right_indices, split_idx + 1)

    def add_split(self, feature, threshold):
        """Append one (feature, threshold) split and rebuild the tree."""
        self.splits.append((feature, threshold))
        self.build_tree_from_splits()

    def remove_last_split(self):
        """Drop the most recent split and rebuild (root becomes None if empty)."""
        if self.splits:
            self.splits.pop()
            # build_tree_from_splits now handles the empty case itself.
            self.build_tree_from_splits()

    def draw_tree(self, node=None, ax=None, x=0.5, y=1.0, dx=0.25, level=0):
        """Recursively draw *node* and its subtree on axes *ax*.

        (x, y) is the node centre in axes coordinates; *dx* is the horizontal
        offset to each child and halves at every level.  Left child = split
        condition true ("Yes"), right child = false ("No").
        """
        if node is None:
            return
        # Leaves are tinted by predicted class; internal nodes are green.
        if node.is_leaf:
            box_color = 'lightblue' if node.majority_class == 0 else 'orange'
            alpha = 0.7
        else:
            box_color = 'lightgreen'
            alpha = 0.5
        if node.is_leaf:
            text = f"Leaf\nClass: {node.majority_class}\n"
            text += f"Samples: {node.samples}\n"
            text += f"Entropy: {node.entropy:.3f}\n"
            text += f"Gini: {node.gini:.3f}"
        else:
            feature_name = "Width" if node.feature == 'x' else "Height"
            # Fix: show the comparison operator (was "Width5.00"), matching
            # the "x≤…" style used by the partition-plot legend.
            text = f"{feature_name} ≤ {node.threshold:.2f}\n"
            text += f"Samples: {node.samples}\n"
            text += f"Entropy: {node.entropy:.3f}\n"
            text += f"Gini: {node.gini:.3f}"
        bbox = dict(boxstyle="round,pad=0.3", facecolor=box_color,
                    edgecolor='black', linewidth=2, alpha=alpha)
        ax.text(x, y, text, ha='center', va='center', fontsize=8,
                bbox=bbox, fontweight='bold')
        # Arrows and recursion into the children.
        if not node.is_leaf and node.left and node.right:
            y_child = y - 0.15
            x_left = x - dx
            x_right = x + dx
            arrow_left = FancyArrowPatch((x, y - 0.05), (x_left, y_child + 0.05),
                                         arrowstyle='->', mutation_scale=20,
                                         linewidth=2, color='blue')
            arrow_right = FancyArrowPatch((x, y - 0.05), (x_right, y_child + 0.05),
                                          arrowstyle='->', mutation_scale=20,
                                          linewidth=2, color='red')
            ax.add_patch(arrow_left)
            ax.add_patch(arrow_right)
            ax.text((x + x_left) / 2, (y + y_child) / 2, 'Yes',
                    fontsize=9, color='blue', fontweight='bold')
            ax.text((x + x_right) / 2, (y + y_child) / 2, 'No',
                    fontsize=9, color='red', fontweight='bold')
            self.draw_tree(node.left, ax, x_left, y_child, dx * 0.5, level + 1)
            self.draw_tree(node.right, ax, x_right, y_child, dx * 0.5, level + 1)

    def visualize(self, split_history):
        """Render partition plot, tree diagram and statistics as one PIL image.

        *split_history* is multiline text, one "feature, threshold" pair per
        line (feature 'x' or 'y').  Malformed lines are silently skipped.
        """
        fig = plt.figure(figsize=(20, 10))
        gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1.2, 1])
        ax1 = fig.add_subplot(gs[0, 0])  # partitioned feature space
        ax2 = fig.add_subplot(gs[1, 0])  # decision tree diagram
        ax3 = fig.add_subplot(gs[:, 1])  # statistics text panel
        # --- Parse the split history text ---
        self.splits = []
        if split_history.strip():
            for line in split_history.strip().split('\n'):
                parts = line.split(',')
                if len(parts) != 2:
                    continue
                feature = parts[0].strip().lower()
                # Fix: reject unknown feature names; previously any token
                # other than 'x' was silently treated as a 'y' split.
                if feature not in ('x', 'y'):
                    continue
                try:
                    self.splits.append((feature, float(parts[1].strip())))
                except ValueError:
                    pass  # non-numeric threshold: skip the line
        # Rebuild unconditionally: also clears self.root when the split list
        # is empty (previously a stale tree from the last render was drawn).
        self.build_tree_from_splits()
        # === Plot 1: partitioned feature space ===
        ax1.scatter(self.X[self.y == 0, 0], self.X[self.y == 0, 1],
                    c='blue', label='Class 0 (Lemon)', s=100, alpha=0.6, edgecolors='k')
        ax1.scatter(self.X[self.y == 1, 0], self.X[self.y == 1, 1],
                    c='orange', label='Class 1 (Orange)', s=100, alpha=0.6, edgecolors='k')
        # One dashed line per split, rainbow-coloured in application order.
        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.splits)))
        for idx, (feature, threshold) in enumerate(self.splits):
            if feature == 'x':
                ax1.axvline(x=threshold, color=colors[idx], linewidth=2.5,
                            linestyle='--', label=f'Split {idx+1}: x≤{threshold:.1f}', alpha=0.8)
            else:
                ax1.axhline(y=threshold, color=colors[idx], linewidth=2.5,
                            linestyle='--', label=f'Split {idx+1}: y≤{threshold:.1f}', alpha=0.8)
        ax1.set_xlabel('Feature 1 (Width)', fontsize=14, fontweight='bold')
        ax1.set_ylabel('Feature 2 (Height)', fontsize=14, fontweight='bold')
        ax1.set_title('Partitioned Feature Space', fontsize=16, fontweight='bold')
        ax1.legend(fontsize=10, loc='upper left')
        ax1.grid(True, alpha=0.3)
        ax1.set_xlim(0, 10)
        ax1.set_ylim(0, 10)
        # === Plot 2: decision tree ===
        ax2.set_xlim(0, 1)
        ax2.set_ylim(0, 1)
        ax2.axis('off')
        ax2.set_title('Decision Tree Structure', fontsize=16, fontweight='bold', pad=20)
        if self.root:
            self.draw_tree(self.root, ax2)
        else:
            ax2.text(0.5, 0.5, 'No splits yet\nAdd splits to build the tree',
                     ha='center', va='center', fontsize=14,
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        # === Plot 3: statistics ===
        ax3.axis('off')
        entropy_initial = self.calculate_entropy(self.y)
        gini_initial = self.calculate_gini(self.y)
        stats_text = "DECISION TREE STATISTICS\n" + "="*50 + "\n\n"
        stats_text += f"Total Samples: {len(self.y)}\n"
        stats_text += f" • Class 0: {np.sum(self.y == 0)}\n"
        stats_text += f" • Class 1: {np.sum(self.y == 1)}\n\n"
        stats_text += f"Initial Impurity:\n"
        stats_text += f" • Entropy: {entropy_initial:.4f}\n"
        stats_text += f" • Gini: {gini_initial:.4f}\n\n"
        if self.splits:
            stats_text += f"Number of Splits: {len(self.splits)}\n\n"
            stats_text += "SPLIT SEQUENCE:\n" + "-"*50 + "\n"
            for idx, (feature, threshold) in enumerate(self.splits):
                feature_name = "Width (x)" if feature == 'x' else "Height (y)"
                # Fix: include the comparison operator (was "Width (x)5.00").
                stats_text += f"\n{idx+1}. {feature_name} ≤ {threshold:.2f}\n"
            # Per-leaf breakdown.
            leaves = []
            self._collect_leaves(self.root, leaves)
            if leaves:
                stats_text += f"\n\nLEAF NODES: {len(leaves)}\n" + "-"*50 + "\n"
                for idx, leaf in enumerate(leaves):
                    stats_text += f"\nLeaf {idx+1}:\n"
                    stats_text += f" • Samples: {leaf.samples}\n"
                    stats_text += f" • Class 0: {leaf.class_counts.get(0, 0)} | "
                    stats_text += f"Class 1: {leaf.class_counts.get(1, 0)}\n"
                    stats_text += f" • Prediction: Class {leaf.majority_class}\n"
                    stats_text += f" • Entropy: {leaf.entropy:.4f}\n"
                    stats_text += f" • Gini: {leaf.gini:.4f}\n"
                # Sample-weighted impurity across leaves -> information gain.
                total_samples = sum(leaf.samples for leaf in leaves)
                avg_entropy = sum(leaf.entropy * leaf.samples for leaf in leaves) / total_samples
                avg_gini = sum(leaf.gini * leaf.samples for leaf in leaves) / total_samples
                stats_text += f"\n\nWEIGHTED AVERAGE IMPURITY:\n" + "-"*50 + "\n"
                stats_text += f" • Entropy: {avg_entropy:.4f}\n"
                stats_text += f" • Gini: {avg_gini:.4f}\n"
                stats_text += f"\nTOTAL INFORMATION GAIN:\n"
                stats_text += f" • {entropy_initial - avg_entropy:.4f}\n"
                stats_text += f"\nTOTAL GINI REDUCTION:\n"
                stats_text += f" • {gini_initial - avg_gini:.4f}\n"
        else:
            stats_text += "No splits applied yet.\n"
            stats_text += "\nAdd splits in the format:\n"
            stats_text += " feature, threshold\n\n"
            stats_text += "Example:\n"
            stats_text += " x, 5.0\n"
            stats_text += " y, 6.5\n"
        ax3.text(0.05, 0.95, stats_text, transform=ax3.transAxes,
                 fontsize=10, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                 family='monospace')
        plt.tight_layout()
        # Rasterize the figure into a PIL image for Gradio.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)  # fix: close this specific figure, not whichever is current
        return img

    def _collect_leaves(self, node, leaves):
        """Append every leaf under *node* to *leaves* (pre-order)."""
        if node is None:
            return
        if node.is_leaf:
            leaves.append(node)
        else:
            self._collect_leaves(node.left, leaves)
            self._collect_leaves(node.right, leaves)
# Create the partitioner
# Single shared instance: its bound methods serve as the Gradio callbacks,
# so all sessions share the same fixed dataset (state is rebuilt per render).
partitioner = DecisionTreePartitioner()
# Create Gradio interface
with gr.Blocks(title="Multi-Split Decision Tree Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🌳 Interactive Multi-Split Decision Tree Visualizer
Build a decision tree step-by-step and visualize the partitioning process!
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Free-text split list: one "feature, threshold" pair per line,
            # parsed by DecisionTreePartitioner.visualize.
            split_input = gr.Textbox(
                label="📝 Split Sequence (one per line: feature, threshold)",
                placeholder="x, 5.0\ny, 6.5\nx, 3.0",
                lines=10,
                value="x, 5.0"
            )
            update_btn = gr.Button("🔄 Update Visualization", variant="primary", size="lg")
            gr.Markdown("""
### Example Splits:
**Simple 2-split tree:**
```
x, 5.0
y, 6.5
```
**Complex 4-split tree:**
```
x, 5.0
y, 6.5
x, 3.0
y, 8.0
```
""")
        with gr.Column(scale=2):
            # Rendered PNG (partition plot + tree diagram + statistics panel).
            output_image = gr.Image(label="Visualization", height=800)
    # Update visualization
    # Button click re-parses the textbox and re-renders the figure.
    update_btn.click(
        fn=partitioner.visualize,
        inputs=[split_input],
        outputs=output_image
    )
    # Initial visualization
    # Render once on page load using the textbox's default value ("x, 5.0").
    demo.load(
        fn=partitioner.visualize,
        inputs=[split_input],
        outputs=output_image
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()