| import gradio as gr |
| from sklearn.datasets import load_iris, load_wine, load_breast_cancer |
| from sklearn.model_selection import train_test_split |
| from sklearn.tree import DecisionTreeClassifier, plot_tree |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| import io |
|
|
| |
# Datasets selectable in the UI, keyed by their display name. Every loader
# runs exactly once, at import time, in the order listed below; the dict's
# insertion order follows that same order.
DATASETS = {
    display_name: loader()
    for display_name, loader in (
        ("Iris", load_iris),
        ("Wine", load_wine),
        ("Breast Cancer", load_breast_cancer),
    )
}
|
|
def train_and_plot(dataset_name, criterion, splitter, max_depth):
    """Train a decision tree on one of ``DATASETS`` and render the fitted tree.

    Args:
        dataset_name: Key into ``DATASETS`` selecting the dataset.
        criterion: Split-quality criterion forwarded to
            ``DecisionTreeClassifier``.
        splitter: Split strategy forwarded to ``DecisionTreeClassifier``.
        max_depth: Maximum tree depth; ``0`` (or ``None``) means unlimited.
            Coerced to ``int`` because a slider may deliver a float, which
            the classifier does not accept.

    Returns:
        A ``(text, image)`` pair: the held-out accuracy formatted as a
        percentage, and the rendered tree as an RGBA uint8 numpy array.

    Raises:
        gr.Error: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    # Local imports keep this rendering path independent of pyplot.
    import numpy as np
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    if dataset_name not in DATASETS:
        raise gr.Error(f"Unknown dataset: {dataset_name!r}")
    data = DATASETS[dataset_name]
    X = pd.DataFrame(data.data, columns=data.feature_names)
    y = data.target

    # Fixed seed: identical settings always give the same split and score.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        # 0 / None from the UI mean "no depth limit".
        max_depth=int(max_depth) if max_depth else None,
        random_state=42,
    )
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_test, y_test)

    # Draw on a standalone Agg-backed Figure instead of via pyplot: pyplot's
    # "current figure" is process-global (unsafe if requests overlap) and it
    # may try to start a GUI backend off the main thread. The Agg canvas is
    # attached before plotting because plot_tree asks the figure's canvas for
    # a renderer when sizing its node boxes.
    fig = Figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    ax = fig.subplots()
    plot_tree(
        clf,
        feature_names=data.feature_names,
        class_names=data.target_names,
        filled=True,
        ax=ax,
    )
    fig.tight_layout()
    canvas.draw()
    # np.array copies the pixels so the result does not alias the canvas.
    image = np.array(canvas.buffer_rgba())

    return f"Model Accuracy: {accuracy:.2%}", image


demo = gr.Interface(
    fn=train_and_plot,
    inputs=[
        # Explicit defaults so submitting untouched controls never sends
        # None; "gini" / "best" are DecisionTreeClassifier's own defaults.
        gr.Dropdown(
            list(DATASETS), value=next(iter(DATASETS)), label="Choose Dataset"
        ),
        gr.Dropdown(
            ["gini", "entropy", "log_loss"], value="gini", label="Criterion"
        ),
        gr.Dropdown(["best", "random"], value="best", label="Splitter"),
        gr.Slider(0, 10, step=1, value=3, label="Max Depth (0 for unlimited)"),
    ],
    outputs=[
        gr.Textbox(label="Accuracy"),
        # "file" is not a valid gr.Image type; the handler returns a numpy
        # image array, so declare the matching "numpy" type.
        gr.Image(type="numpy", label="Decision Tree Plot"),
    ],
    title="Interactive Decision Tree Classifier",
    description="Adjust the parameters and see how the Decision Tree changes.",
)


if __name__ == "__main__":
    demo.launch()
|
|