| | import gradio as gr |
| | from sklearn.datasets import load_iris, load_wine, load_breast_cancer |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.tree import DecisionTreeClassifier, plot_tree |
| | import matplotlib.pyplot as plt |
| | import pandas as pd |
| | import io |
| |
|
| | |
# Datasets offered in the UI, keyed by the label shown in the dropdown.
# Each loader is called once, at import time, and the loaded result is reused
# by every request.
DATASETS = {
    label: loader()
    for label, loader in (
        ("Iris", load_iris),
        ("Wine", load_wine),
        ("Breast Cancer", load_breast_cancer),
    )
}
| |
|
def train_and_plot(dataset_name, criterion, splitter, max_depth):
    """Train a decision tree on a built-in dataset and render the fitted tree.

    Args:
        dataset_name: Key into ``DATASETS`` selecting which dataset to use.
        criterion: Split-quality measure forwarded to ``DecisionTreeClassifier``.
        splitter: Split strategy forwarded to ``DecisionTreeClassifier``.
        max_depth: Maximum tree depth. ``0`` (or ``None``) means unlimited.
            Whole-number floats such as ``3.0`` are accepted and truncated to
            ``int``, since a UI slider may deliver its value as a float.

    Returns:
        A ``(message, image)`` tuple. ``message`` is the accuracy on the
        held-out 20% split formatted as a percentage; ``image`` is the
        rendered tree as a ``uint8`` NumPy array of shape
        ``(height, width, channels)``.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    data = DATASETS[dataset_name]
    X = pd.DataFrame(data.data, columns=data.feature_names)
    y = data.target

    # Fixed seed so the same settings always produce the same split and tree.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # scikit-learn expects an integer depth or None; a float such as 3.0 is
    # rejected by parameter validation in recent releases.
    depth = int(max_depth) if max_depth else None

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=depth,
        random_state=42,
    )
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_test, y_test)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        plot_tree(
            clf,
            # Plain lists of str: the dataset bunches expose these as NumPy
            # arrays, which some scikit-learn releases refuse here.
            feature_names=[str(name) for name in data.feature_names],
            class_names=[str(name) for name in data.target_names],
            filled=True,
            ax=ax,
        )
        # Use the figure's own methods rather than pyplot's "current figure",
        # which is global state shared between concurrent requests.
        fig.tight_layout()

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            buf.seek(0)
            # imread decodes PNG data to floats in [0, 1].
            pixels = plt.imread(buf, format="png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    # Return an array instead of the raw BytesIO: an image output component
    # can display an array, but not an in-memory byte stream.
    image = (pixels * 255).round().astype("uint8")

    return f"Model Accuracy: {accuracy:.2%}", image
| |
|
| | |
# Gradio UI: four controls feed train_and_plot, which returns an accuracy
# message and an image of the fitted tree.
demo = gr.Interface(
    fn=train_and_plot,
    inputs=[
        # Every dropdown gets an explicit default so that pressing "Submit"
        # without touching the controls never sends None to the callback.
        gr.Dropdown(
            list(DATASETS),
            value=next(iter(DATASETS)),
            label="Choose Dataset",
        ),
        gr.Dropdown(
            ["gini", "entropy", "log_loss"],
            value="gini",
            label="Criterion",
        ),
        gr.Dropdown(["best", "random"], value="best", label="Splitter"),
        gr.Slider(0, 10, step=1, value=3, label="Max Depth (0 for unlimited)"),
    ],
    outputs=[
        gr.Textbox(label="Accuracy"),
        # No `type=` argument: it only controls how an *input* image is handed
        # to the function, and "file" is not one of the documented values
        # ("numpy", "pil", "filepath").
        gr.Image(label="Decision Tree Plot"),
    ],
    title="Interactive Decision Tree Classifier",
    description="Adjust the parameters and see how the Decision Tree changes.",
)


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
| |
|