rajendrr commited on
Commit
494d188
·
verified ·
1 Parent(s): 864a422

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -1,15 +1,13 @@
1
  import gradio as gr
2
- import re
3
- import pandas as pd
4
- import numpy as np
5
  from sklearn.datasets import load_iris, load_wine, load_breast_cancer
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.tree import DecisionTreeClassifier, plot_tree
8
  import matplotlib.pyplot as plt
9
  import pandas as pd
 
10
  import io
11
 
12
- # Load datasets into a dictionary for easy selection
13
  DATASETS = {
14
  "Iris": load_iris(),
15
  "Wine": load_wine(),
@@ -17,15 +15,12 @@ DATASETS = {
17
  }
18
 
19
  def train_and_plot(dataset_name, criterion, splitter, max_depth):
20
- # Load chosen dataset
21
  data = DATASETS[dataset_name]
22
  X = pd.DataFrame(data.data, columns=data.feature_names)
23
  y = data.target
24
 
25
- # Split into train/test
26
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
27
 
28
- # Create Decision Tree model
29
  clf = DecisionTreeClassifier(
30
  criterion=criterion,
31
  splitter=splitter,
@@ -33,24 +28,21 @@ def train_and_plot(dataset_name, criterion, splitter, max_depth):
33
  random_state=42
34
  )
35
  clf.fit(X_train, y_train)
36
-
37
- # Calculate accuracy
38
  accuracy = clf.score(X_test, y_test)
39
 
40
- # Plot decision tree
41
  fig, ax = plt.subplots(figsize=(12, 8))
42
  plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True, ax=ax)
43
  plt.tight_layout()
44
 
45
- # Save plot to a file-like object
46
  buf = io.BytesIO()
47
  plt.savefig(buf, format="png")
48
  buf.seek(0)
 
49
  plt.close(fig)
50
 
51
- return f"Model Accuracy: {accuracy:.2%}", buf
52
 
53
- # Create Gradio interface
54
  demo = gr.Interface(
55
  fn=train_and_plot,
56
  inputs=[
@@ -61,7 +53,7 @@ demo = gr.Interface(
61
  ],
62
  outputs=[
63
  gr.Textbox(label="Accuracy"),
64
- gr.Image(type="file", label="Decision Tree Plot")
65
  ],
66
  title="Interactive Decision Tree Classifier",
67
  description="Adjust the parameters and see how the Decision Tree changes."
 
1
  import gradio as gr
 
 
 
2
  from sklearn.datasets import load_iris, load_wine, load_breast_cancer
3
  from sklearn.model_selection import train_test_split
4
  from sklearn.tree import DecisionTreeClassifier, plot_tree
5
  import matplotlib.pyplot as plt
6
  import pandas as pd
7
+ from PIL import Image
8
  import io
9
 
10
+ # Datasets
11
  DATASETS = {
12
  "Iris": load_iris(),
13
  "Wine": load_wine(),
 
15
  }
16
 
17
  def train_and_plot(dataset_name, criterion, splitter, max_depth):
 
18
  data = DATASETS[dataset_name]
19
  X = pd.DataFrame(data.data, columns=data.feature_names)
20
  y = data.target
21
 
 
22
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
23
 
 
24
  clf = DecisionTreeClassifier(
25
  criterion=criterion,
26
  splitter=splitter,
 
28
  random_state=42
29
  )
30
  clf.fit(X_train, y_train)
 
 
31
  accuracy = clf.score(X_test, y_test)
32
 
33
+ # Plot tree
34
  fig, ax = plt.subplots(figsize=(12, 8))
35
  plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True, ax=ax)
36
  plt.tight_layout()
37
 
 
38
  buf = io.BytesIO()
39
  plt.savefig(buf, format="png")
40
  buf.seek(0)
41
+ img = Image.open(buf) # Convert to PIL
42
  plt.close(fig)
43
 
44
+ return f"Model Accuracy: {accuracy:.2%}", img
45
 
 
46
  demo = gr.Interface(
47
  fn=train_and_plot,
48
  inputs=[
 
53
  ],
54
  outputs=[
55
  gr.Textbox(label="Accuracy"),
56
+ gr.Image(type="pil", label="Decision Tree Plot")
57
  ],
58
  title="Interactive Decision Tree Classifier",
59
  description="Adjust the parameters and see how the Decision Tree changes."