selva1909 commited on
Commit
20b258b
·
verified ·
1 Parent(s): 687c512

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -57
app.py CHANGED
@@ -1,15 +1,20 @@
1
  import gradio as gr
2
  import numpy as np
3
  import plotly.graph_objects as go
4
- import matplotlib.pyplot as plt
5
  from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
6
  from sklearn.metrics import mean_squared_error, accuracy_score
7
- from sklearn.tree import plot_tree
8
 
9
  # ------------------------------------------------
10
  # DATA GENERATORS
11
  # ------------------------------------------------
12
 
 
 
 
 
 
 
 
13
  def generate_3d_regression(n_points, noise):
14
  n_points = int(n_points)
15
  x1 = np.linspace(0, 10, n_points)
@@ -19,7 +24,7 @@ def generate_3d_regression(n_points, noise):
19
 
20
  X_flat = np.column_stack((X1.ravel(), X2.ravel()))
21
  Z_flat = Z.ravel()
22
- return X1, X2, X_flat, Z_flat
23
 
24
 
25
  def generate_classification(n_points, noise):
@@ -31,54 +36,87 @@ def generate_classification(n_points, noise):
31
 
32
 
33
  # ------------------------------------------------
34
- # INTERACTIVE 3D RANDOM FOREST
35
  # ------------------------------------------------
36
 
37
- def interactive_3d(n_points, noise, n_estimators, max_depth):
38
  n_estimators = int(n_estimators)
39
  max_depth = int(max_depth)
40
 
41
- X1, X2, X_flat, Z_flat = generate_3d_regression(n_points, noise)
42
 
43
  rf = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
44
- rf.fit(X_flat, Z_flat)
45
 
46
- Z_pred = rf.predict(X_flat).reshape(X1.shape)
47
- mse = mean_squared_error(Z_flat, Z_pred.ravel())
48
 
49
  fig = go.Figure()
50
 
51
- fig.add_surface(x=X1, y=X2, z=Z_pred, colorscale="Blues", opacity=0.9)
 
52
 
53
  fig.update_layout(
54
- title="Interactive 3D Random Forest Surface",
55
- scene=dict(bgcolor="#0b1e3d"),
56
  paper_bgcolor="#0b1e3d",
 
57
  font=dict(color="white")
58
  )
59
 
60
- return fig, f"MSE: {mse:.4f}", rf
61
 
62
 
63
  # ------------------------------------------------
64
- # SINGLE TREE VISUALIZATION
65
  # ------------------------------------------------
66
 
67
- def show_tree(rf_model):
68
- if rf_model is None:
69
- return None
 
 
 
 
 
70
 
71
- tree = rf_model.estimators_[0]
 
72
 
73
- fig, ax = plt.subplots(figsize=(10, 6))
74
- plot_tree(tree, ax=ax, filled=True)
75
- ax.set_title("Single Decision Tree from Random Forest")
76
 
77
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # ------------------------------------------------
81
- # CLASSIFICATION VIEW
82
  # ------------------------------------------------
83
 
84
  def classification_view(n_points, noise, n_estimators, max_depth):
@@ -90,20 +128,24 @@ def classification_view(n_points, noise, n_estimators, max_depth):
90
  rf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
91
  rf.fit(X, y)
92
 
93
- y_pred = rf.predict(X)
94
- acc = accuracy_score(y, y_pred)
 
 
 
 
 
 
 
95
 
96
  fig = go.Figure()
97
 
98
- fig.add_trace(go.Scatter(
99
- x=X[:, 0],
100
- y=X[:, 1],
101
- mode="markers",
102
- marker=dict(color=y_pred, colorscale="Blues"),
103
- ))
104
 
105
  fig.update_layout(
106
- title="Random Forest Classification",
107
  paper_bgcolor="#0b1e3d",
108
  plot_bgcolor="#0b1e3d",
109
  font=dict(color="white")
@@ -113,41 +155,44 @@ def classification_view(n_points, noise, n_estimators, max_depth):
113
 
114
 
115
  # ------------------------------------------------
116
- # GRADIO UI (AUTO RUN, BLUE THEME)
117
  # ------------------------------------------------
118
 
119
  with gr.Blocks() as demo:
120
 
121
- gr.Markdown("# 🌲 Random Forest Teaching Dashboard")
122
-
123
- with gr.Tab("🌐 3D Regression"):
124
- n_points = gr.Slider(20, 100, 40, step=1, label="Points")
125
- noise = gr.Slider(0.0, 3.0, 1.0, label="Noise")
126
- n_estimators = gr.Slider(10, 100, 50, step=1, label="Trees")
127
- max_depth = gr.Slider(2, 15, 8, step=1, label="Depth")
128
-
129
- plot3d = gr.Plot()
130
- mse_text = gr.Markdown()
131
- tree_plot = gr.Plot()
132
 
133
- def run_3d(n_points, noise, n_estimators, max_depth):
134
- fig, text, rf = interactive_3d(n_points, noise, n_estimators, max_depth)
135
- tree_fig = show_tree(rf)
136
- return fig, text, tree_fig
137
 
138
- for inp in [n_points, noise, n_estimators, max_depth]:
139
- inp.change(run_3d, [n_points, noise, n_estimators, max_depth], [plot3d, mse_text, tree_plot])
 
140
 
141
- demo.load(run_3d, [n_points, noise, n_estimators, max_depth], [plot3d, mse_text, tree_plot])
 
 
142
 
143
  with gr.Tab("🧩 Classification"):
144
- cls_plot = gr.Plot()
145
- cls_text = gr.Markdown()
 
 
 
 
 
 
146
 
147
- for inp in [n_points, noise, n_estimators, max_depth]:
148
- inp.change(classification_view, [n_points, noise, n_estimators, max_depth], [cls_plot, cls_text])
 
 
149
 
150
- demo.load(classification_view, [n_points, noise, n_estimators, max_depth], [cls_plot, cls_text])
 
 
151
 
152
 
153
- demo.launch(theme=gr.themes.Soft(primary_hue="blue"))
 
1
  import gradio as gr
2
  import numpy as np
3
  import plotly.graph_objects as go
 
4
  from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
5
  from sklearn.metrics import mean_squared_error, accuracy_score
 
6
 
7
  # ------------------------------------------------
8
  # DATA GENERATORS
9
  # ------------------------------------------------
10
 
11
+ def generate_2d_regression(n_points, noise):
12
+ n_points = int(n_points)
13
+ X = np.linspace(0, 10, n_points)
14
+ y = 2.5 * X + 5 + np.random.randn(n_points) * noise
15
+ return X.reshape(-1, 1), y
16
+
17
+
18
  def generate_3d_regression(n_points, noise):
19
  n_points = int(n_points)
20
  x1 = np.linspace(0, 10, n_points)
 
24
 
25
  X_flat = np.column_stack((X1.ravel(), X2.ravel()))
26
  Z_flat = Z.ravel()
27
+ return X, X1, X2, X_flat, Z_flat
28
 
29
 
30
  def generate_classification(n_points, noise):
 
36
 
37
 
38
  # ------------------------------------------------
39
+ # 2D RANDOM FOREST REGRESSION
40
  # ------------------------------------------------
41
 
42
+ def rf_2d_view(n_points, noise, n_estimators, max_depth):
43
  n_estimators = int(n_estimators)
44
  max_depth = int(max_depth)
45
 
46
+ X, y = generate_2d_regression(n_points, noise)
47
 
48
  rf = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
49
+ rf.fit(X, y)
50
 
51
+ y_pred = rf.predict(X)
52
+ mse = mean_squared_error(y, y_pred)
53
 
54
  fig = go.Figure()
55
 
56
+ fig.add_scatter(x=X.flatten(), y=y, mode="markers", name="Data")
57
+ fig.add_scatter(x=X.flatten(), y=y_pred, mode="lines", name="RF Prediction")
58
 
59
  fig.update_layout(
60
+ title="2D Random Forest Regression",
 
61
  paper_bgcolor="#0b1e3d",
62
+ plot_bgcolor="#0b1e3d",
63
  font=dict(color="white")
64
  )
65
 
66
+ return fig, f"MSE: {mse:.4f}"
67
 
68
 
69
  # ------------------------------------------------
70
+ # 3D RANDOM FOREST REGRESSION (ROTATING CAMERA)
71
  # ------------------------------------------------
72
 
73
+ def rf_3d_view(n_points, noise, n_estimators, max_depth):
74
+ n_estimators = int(n_estimators)
75
+ max_depth = int(max_depth)
76
+
77
+ _, X1, X2, X_flat, Z_flat = generate_3d_regression(n_points, noise)
78
+
79
+ rf = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
80
+ rf.fit(X_flat, Z_flat)
81
 
82
+ Z_pred = rf.predict(X_flat).reshape(X1.shape)
83
+ mse = mean_squared_error(Z_flat, Z_pred.ravel())
84
 
85
+ fig = go.Figure()
 
 
86
 
87
+ fig.add_surface(x=X1, y=X2, z=Z_pred, colorscale="Blues", opacity=0.9)
88
+
89
+ # Smooth rotating camera animation
90
+ frames = []
91
+ for angle in range(0, 360, 10):
92
+ frames.append(go.Frame(layout=dict(scene_camera=dict(eye=dict(
93
+ x=np.cos(np.radians(angle)) * 2,
94
+ y=np.sin(np.radians(angle)) * 2,
95
+ z=1.2
96
+ )))))
97
+
98
+ fig.frames = frames
99
+
100
+ fig.update_layout(
101
+ title="3D Random Forest Regression",
102
+ scene=dict(bgcolor="#0b1e3d"),
103
+ paper_bgcolor="#0b1e3d",
104
+ font=dict(color="white"),
105
+ updatemenus=[dict(
106
+ type="buttons",
107
+ showactive=False,
108
+ buttons=[dict(label="Rotate 3D",
109
+ method="animate",
110
+ args=[None, {"frame": {"duration": 60, "redraw": True},
111
+ "fromcurrent": True}])]
112
+ )]
113
+ )
114
+
115
+ return fig, f"MSE: {mse:.4f}"
116
 
117
 
118
  # ------------------------------------------------
119
+ # CLASSIFICATION VIEW (CLEAR DECISION REGIONS)
120
  # ------------------------------------------------
121
 
122
  def classification_view(n_points, noise, n_estimators, max_depth):
 
128
  rf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
129
  rf.fit(X, y)
130
 
131
+ # Mesh grid for decision boundary
132
+ xx, yy = np.meshgrid(
133
+ np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100),
134
+ np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100)
135
+ )
136
+
137
+ Z = rf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
138
+
139
+ acc = accuracy_score(y, rf.predict(X))
140
 
141
  fig = go.Figure()
142
 
143
+ fig.add_contour(x=xx[0], y=yy[:, 0], z=Z, showscale=False, opacity=0.4, colorscale="Blues")
144
+
145
+ fig.add_scatter(x=X[:, 0], y=X[:, 1], mode="markers", marker=dict(color=y, colorscale="Blues"), name="Data")
 
 
 
146
 
147
  fig.update_layout(
148
+ title="Random Forest Classification Boundary",
149
  paper_bgcolor="#0b1e3d",
150
  plot_bgcolor="#0b1e3d",
151
  font=dict(color="white")
 
155
 
156
 
157
  # ------------------------------------------------
158
+ # GRADIO UI (AUTO-RUN, BEGINNER FRIENDLY)
159
  # ------------------------------------------------
160
 
161
  with gr.Blocks() as demo:
162
 
163
+ gr.Markdown("# 🌲 Random Forest Learning Dashboard (2D • 3D • Classification)")
 
 
 
 
 
 
 
 
 
 
164
 
165
+ n_points = gr.Slider(20, 100, 40, step=1, label="Number of Points")
166
+ noise = gr.Slider(0.0, 3.0, 1.0, label="Noise Level")
167
+ n_estimators = gr.Slider(10, 100, 50, step=1, label="Number of Trees")
168
+ max_depth = gr.Slider(2, 15, 8, step=1, label="Tree Depth")
169
 
170
+ with gr.Tab("📈 2D Regression"):
171
+ plot2d = gr.Plot()
172
+ mse2d = gr.Markdown()
173
 
174
+ with gr.Tab("🌐 3D Regression"):
175
+ plot3d = gr.Plot()
176
+ mse3d = gr.Markdown()
177
 
178
  with gr.Tab("🧩 Classification"):
179
+ plot_cls = gr.Plot()
180
+ acc_cls = gr.Markdown()
181
+
182
+ def run_all(n_points, noise, n_estimators, max_depth):
183
+ fig2d, mse_text = rf_2d_view(n_points, noise, n_estimators, max_depth)
184
+ fig3d, mse3d_text = rf_3d_view(n_points, noise, n_estimators, max_depth)
185
+ fig_cls, acc_text = classification_view(n_points, noise, n_estimators, max_depth)
186
+ return fig2d, mse_text, fig3d, mse3d_text, fig_cls, acc_text
187
 
188
+ for inp in [n_points, noise, n_estimators, max_depth]:
189
+ inp.change(run_all,
190
+ [n_points, noise, n_estimators, max_depth],
191
+ [plot2d, mse2d, plot3d, mse3d, plot_cls, acc_cls])
192
 
193
+ demo.load(run_all,
194
+ [n_points, noise, n_estimators, max_depth],
195
+ [plot2d, mse2d, plot3d, mse3d, plot_cls, acc_cls])
196
 
197
 
198
+ demo.launch(theme=gr.themes.Soft(primary_hue="blue"))