|
|
import gradio as gr |
|
|
from sklearn.datasets import load_iris, load_wine, load_breast_cancer |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.tree import DecisionTreeClassifier, plot_tree |
|
|
import matplotlib.pyplot as plt |
|
|
import pandas as pd |
|
|
import io |
|
|
|
|
|
|
|
|
# Bundled scikit-learn toy datasets offered in the UI, keyed by display name.
# Every loader is called once, eagerly, when this module is imported; the
# dict's insertion order is the order the names appear in the dataset dropdown.
DATASETS = {
    display_name: loader()
    for display_name, loader in (
        ("Iris", load_iris),
        ("Wine", load_wine),
        ("Breast Cancer", load_breast_cancer),
    )
}
|
|
|
|
|
def train_and_plot(dataset_name, criterion, splitter, max_depth):
    """Train a decision tree on one of ``DATASETS`` and render the fitted tree.

    Args:
        dataset_name: Key into ``DATASETS`` selecting the dataset to use.
        criterion: Split-quality measure passed to ``DecisionTreeClassifier``
            (the UI offers "gini", "entropy" and "log_loss").
        splitter: Split strategy passed to ``DecisionTreeClassifier``
            (the UI offers "best" and "random").
        max_depth: Maximum tree depth. ``0`` or ``None`` means unlimited.
            Numeric values are coerced to ``int`` because slider widgets may
            deliver floats such as ``3.0``.

    Returns:
        A tuple ``(accuracy_text, image)``. ``accuracy_text`` is the test-set
        accuracy formatted as a percentage. ``image`` is the rendered tree as
        an RGBA numpy array decoded from a PNG, which ``gr.Image`` can display.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    data = DATASETS[dataset_name]
    X = pd.DataFrame(data.data, columns=data.feature_names)
    y = data.target

    # Fixed random_state keeps the split (and therefore the reported accuracy)
    # reproducible for identical parameters.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # scikit-learn rejects non-integer depths like 3.0, so normalise here;
    # 0 (or a cleared/None value) means "no depth limit".
    depth = int(max_depth) if max_depth else None

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=depth,
        random_state=42,
    )
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_test, y_test)

    # Operate on the explicit Figure rather than pyplot's process-global
    # "current figure", and always close it so a plotting error cannot leak
    # an open figure.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        plot_tree(
            clf,
            feature_names=list(data.feature_names),
            class_names=list(data.target_names),
            filled=True,
            ax=ax,
        )
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    # gr.Image cannot display a raw BytesIO; decode the PNG into an array.
    buf.seek(0)
    image = plt.imread(buf)

    return f"Model Accuracy: {accuracy:.2%}", image
|
|
|
|
|
|
|
|
# Dataset names in DATASETS insertion order; the first one is the default so
# that submitting without touching the dropdown does not pass None.
_dataset_choices = list(DATASETS.keys())

demo = gr.Interface(
    fn=train_and_plot,
    inputs=[
        gr.Dropdown(
            _dataset_choices, value=_dataset_choices[0], label="Choose Dataset"
        ),
        gr.Dropdown(
            ["gini", "entropy", "log_loss"], value="gini", label="Criterion"
        ),
        gr.Dropdown(["best", "random"], value="best", label="Splitter"),
        gr.Slider(0, 10, step=1, value=3, label="Max Depth (0 for unlimited)"),
    ],
    outputs=[
        gr.Textbox(label="Accuracy"),
        # "file" is not a valid gr.Image type (numpy / pil / filepath are);
        # "numpy" matches an array-valued image output.
        gr.Image(type="numpy", label="Decision Tree Plot"),
    ],
    title="Interactive Decision Tree Classifier",
    description="Adjust the parameters and see how the Decision Tree changes.",
)


if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()
|
|
|