rajendrr commited on
Commit
146bf7a
·
verified ·
1 Parent(s): 780c52e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sklearn.datasets import load_iris, load_wine, load_breast_cancer
3
+ from sklearn.model_selection import train_test_split
4
+ from sklearn.tree import DecisionTreeClassifier, plot_tree
5
+ import matplotlib.pyplot as plt
6
+ import pandas as pd
7
+ import io
8
+
9
+ # Load datasets into a dictionary for easy selection
10
+ DATASETS = {
11
+ "Iris": load_iris(),
12
+ "Wine": load_wine(),
13
+ "Breast Cancer": load_breast_cancer()
14
+ }
15
+
16
+ def train_and_plot(dataset_name, criterion, splitter, max_depth):
17
+ # Load chosen dataset
18
+ data = DATASETS[dataset_name]
19
+ X = pd.DataFrame(data.data, columns=data.feature_names)
20
+ y = data.target
21
+
22
+ # Split into train/test
23
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
24
+
25
+ # Create Decision Tree model
26
+ clf = DecisionTreeClassifier(
27
+ criterion=criterion,
28
+ splitter=splitter,
29
+ max_depth=None if max_depth == 0 else max_depth,
30
+ random_state=42
31
+ )
32
+ clf.fit(X_train, y_train)
33
+
34
+ # Calculate accuracy
35
+ accuracy = clf.score(X_test, y_test)
36
+
37
+ # Plot decision tree
38
+ fig, ax = plt.subplots(figsize=(12, 8))
39
+ plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True, ax=ax)
40
+ plt.tight_layout()
41
+
42
+ # Save plot to a file-like object
43
+ buf = io.BytesIO()
44
+ plt.savefig(buf, format="png")
45
+ buf.seek(0)
46
+ plt.close(fig)
47
+
48
+ return f"Model Accuracy: {accuracy:.2%}", buf
49
+
50
+ # Create Gradio interface
51
+ demo = gr.Interface(
52
+ fn=train_and_plot,
53
+ inputs=[
54
+ gr.Dropdown(list(DATASETS.keys()), label="Choose Dataset"),
55
+ gr.Dropdown(["gini", "entropy", "log_loss"], label="Criterion"),
56
+ gr.Dropdown(["best", "random"], label="Splitter"),
57
+ gr.Slider(0, 10, step=1, value=3, label="Max Depth (0 for unlimited)")
58
+ ],
59
+ outputs=[
60
+ gr.Textbox(label="Accuracy"),
61
+ gr.Image(type="file", label="Decision Tree Plot")
62
+ ],
63
+ title="Interactive Decision Tree Classifier",
64
+ description="Adjust the parameters and see how the Decision Tree changes."
65
+ )
66
+
67
+ if __name__ == "__main__":
68
+ demo.launch()