File size: 1,944 Bytes
146bf7a
 
 
 
 
 
494d188
146bf7a
 
494d188
146bf7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494d188
146bf7a
 
 
 
 
 
 
494d188
146bf7a
 
494d188
146bf7a
 
 
 
 
 
 
 
 
 
 
494d188
146bf7a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
from sklearn.datasets import load_iris, load_wine, load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import io

# Bundled scikit-learn datasets, keyed by the label shown in the UI.
# Every loader is called once at import time, in the order listed.
DATASETS = {
    label: loader()
    for label, loader in (
        ("Iris", load_iris),
        ("Wine", load_wine),
        ("Breast Cancer", load_breast_cancer),
    )
}

def train_and_plot(dataset_name, criterion, splitter, max_depth):
    """Train a decision tree on a bundled dataset and render it as an image.

    Args:
        dataset_name: Key into ``DATASETS`` selecting the dataset to use.
        criterion: Split-quality measure forwarded to
            ``DecisionTreeClassifier``.
        splitter: Split strategy forwarded to ``DecisionTreeClassifier``.
        max_depth: Maximum tree depth. ``0`` (or any non-positive value)
            means unlimited. Whole-number floats are accepted and
            converted to ``int``.

    Returns:
        A ``(message, image)`` tuple: the accuracy on the 20% held-out
        split formatted as a percentage string, and a PIL image of the
        fitted tree.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    data = DATASETS[dataset_name]

    # Some datasets expose their names as NumPy arrays; plain lists of str
    # are accepted by every plot_tree version.
    feature_names = [str(name) for name in data.feature_names]
    class_names = [str(name) for name in data.target_names]

    X = pd.DataFrame(data.data, columns=feature_names)
    y = data.target

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # UI sliders may deliver floats (e.g. 3.0); scikit-learn requires an
    # integer (or None) for max_depth.
    depth = int(max_depth)

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=depth if depth > 0 else None,
        random_state=42
    )
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_test, y_test)

    # Plot tree. Work on the figure object directly instead of pyplot's
    # global "current figure", and always release it, even if rendering fails.
    fig, ax = plt.subplots(figsize=(12, 8))
    buf = io.BytesIO()
    try:
        plot_tree(clf, feature_names=feature_names, class_names=class_names, filled=True, ax=ax)
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)  # Convert to PIL

    return f"Model Accuracy: {accuracy:.2%}", img

# Gradio UI wiring for train_and_plot. Every dropdown gets an explicit
# default so that pressing "Submit" without touching the controls never
# passes None to the training function.
demo = gr.Interface(
    fn=train_and_plot,
    inputs=[
        gr.Dropdown(
            list(DATASETS),
            value=next(iter(DATASETS)),
            label="Choose Dataset",
        ),
        gr.Dropdown(["gini", "entropy", "log_loss"], value="gini", label="Criterion"),
        gr.Dropdown(["best", "random"], value="best", label="Splitter"),
        gr.Slider(0, 10, step=1, value=3, label="Max Depth (0 for unlimited)")
    ],
    outputs=[
        gr.Textbox(label="Accuracy"),
        gr.Image(type="pil", label="Decision Tree Plot")
    ],
    title="Interactive Decision Tree Classifier",
    description="Adjust the parameters and see how the Decision Tree changes."
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()