selva1909 committed on
Commit
a9bf4b6
·
verified ·
1 Parent(s): 114cfce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -60
app.py CHANGED
@@ -1,16 +1,43 @@
1
  import gradio as gr
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
- from sklearn.ensemble import RandomForestRegressor
5
- from sklearn.metrics import mean_squared_error
6
- from mpl_toolkits.mplot3d import Axes3D
7
 
8
  # ------------------------------------------------
9
- # Random Forest 2D
10
  # ------------------------------------------------
11
- def rf_2d(n_points, noise, n_estimators, max_depth):
 
12
  X = np.linspace(0, 10, n_points).reshape(-1, 1)
13
  y = 2.5 * X.flatten() + 5 + np.random.randn(n_points) * noise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  model = RandomForestRegressor(
16
  n_estimators=n_estimators,
@@ -23,26 +50,31 @@ def rf_2d(n_points, noise, n_estimators, max_depth):
23
  mse = mean_squared_error(y, y_pred)
24
 
25
  fig, ax = plt.subplots(figsize=(5, 4))
26
- ax.scatter(X, y, s=20, color="orange", label="Data")
27
- ax.plot(X, y_pred, color="blue", linewidth=2, label="RF Prediction")
 
 
 
 
 
 
28
  ax.set_title("2D Random Forest Regression")
29
  ax.legend()
30
 
31
- return fig, f"MSE: {mse:.4f}"
 
 
 
 
 
32
 
33
 
34
  # ------------------------------------------------
35
- # Random Forest 3D
36
  # ------------------------------------------------
37
- def rf_3d(n_points, noise, n_estimators, max_depth):
38
- x1 = np.linspace(0, 10, n_points)
39
- x2 = np.linspace(0, 10, n_points)
40
- X1, X2 = np.meshgrid(x1, x2)
41
-
42
- Z = 3 * X1 + 2 * X2 + 10 + np.random.randn(*X1.shape) * noise
43
 
44
- X_flat = np.column_stack((X1.ravel(), X2.ravel()))
45
- Z_flat = Z.ravel()
46
 
47
  model = RandomForestRegressor(
48
  n_estimators=n_estimators,
@@ -54,66 +86,95 @@ def rf_3d(n_points, noise, n_estimators, max_depth):
54
  Z_pred = model.predict(X_flat).reshape(X1.shape)
55
  mse = mean_squared_error(Z_flat, Z_pred.ravel())
56
 
57
- fig = plt.figure(figsize=(5, 4))
58
- ax = fig.add_subplot(111, projection="3d")
59
-
60
- idx = np.random.choice(len(Z_flat), 400, replace=False)
61
- ax.scatter(
62
- X_flat[idx, 0],
63
- X_flat[idx, 1],
64
- Z_flat[idx],
65
- s=8,
66
- alpha=0.3,
67
- color="orange"
68
- )
69
 
70
- ax.plot_surface(X1, X2, Z_pred, alpha=0.7, color="blue")
71
- ax.set_title("3D Random Forest Surface")
 
72
 
73
- return fig, f"MSE: {mse:.4f}"
 
 
74
 
75
 
76
  # ------------------------------------------------
77
- # Gradio UI
78
  # ------------------------------------------------
79
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
80
- gr.Markdown(
81
- """
82
- # 🌲 Random Forest Regression Visualizer
83
- Interactive **2D & 3D Random Forest** playground
84
- """
 
 
85
  )
 
 
 
 
86
 
87
- with gr.Row():
88
- mode = gr.Radio(
89
- ["2D Regression", "3D Regression"],
90
- value="2D Regression",
91
- label="Mode"
92
- )
 
 
93
 
94
- with gr.Row():
95
- n_points = gr.Slider(20, 200, value=80, step=10, label="Data Points")
96
- noise = gr.Slider(0.0, 5.0, value=1.0, label="Noise Level")
97
 
98
- with gr.Row():
99
- n_estimators = gr.Slider(10, 200, value=50, step=10, label="Trees")
100
- max_depth = gr.Slider(2, 20, value=8, step=1, label="Max Depth")
101
 
102
- run_btn = gr.Button("🌲 Train Random Forest")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  plot = gr.Plot()
 
105
  metric = gr.Markdown()
106
 
107
- def run(mode, n_points, noise, n_estimators, max_depth):
108
  if mode == "2D Regression":
109
- return rf_2d(n_points, noise, n_estimators, max_depth)
 
 
 
110
  else:
111
- return rf_3d(n_points, noise, n_estimators, max_depth)
 
 
 
 
 
 
 
112
 
113
- run_btn.click(
114
- run,
115
- inputs=[mode, n_points, noise, n_estimators, max_depth],
116
- outputs=[plot, metric]
117
- )
118
 
119
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
5
+ from sklearn.metrics import mean_squared_error, accuracy_score
 
6
 
7
  # ------------------------------------------------
8
+ # Helpers
9
  # ------------------------------------------------
10
+
11
def generate_2d_regression(n_points, noise):
    """Create a noisy 1-D linear dataset following y = 2.5*x + 5.

    Args:
        n_points: number of samples spaced evenly over [0, 10].
        noise: standard deviation of the Gaussian noise added to y.

    Returns:
        (X, y) where X has shape (n_points, 1) and y has shape (n_points,).
    """
    X = np.linspace(0.0, 10.0, n_points)[:, np.newaxis]
    targets = 2.5 * X.ravel() + 5
    # Gaussian jitter scaled by the requested noise level.
    targets = targets + np.random.randn(n_points) * noise
    return X, targets
15
+
16
+
17
def generate_3d_regression(n_points, noise):
    """Build a noisy planar surface Z = 3*x1 + 2*x2 + 10 over a square grid.

    Args:
        n_points: number of grid steps per axis (grid is n_points x n_points).
        noise: standard deviation of the Gaussian noise added to Z.

    Returns:
        (X1, X2, X_flat, Z, Z_flat) — the meshgrid coordinate matrices, the
        flattened (n_points**2, 2) feature matrix, the noisy surface, and its
        flattened copy.
    """
    axis = np.linspace(0, 10, n_points)
    X1, X2 = np.meshgrid(axis, axis)
    # Plane plus Gaussian jitter, matching the shape of the grid.
    Z = 3 * X1 + 2 * X2 + 10 + np.random.randn(*X1.shape) * noise

    X_flat = np.column_stack((X1.ravel(), X2.ravel()))
    return X1, X2, X_flat, Z, Z.ravel()
26
+
27
+
28
def generate_classification(n_points, noise):
    """Sample a binary-class 2-D dataset with a parabolic decision boundary.

    Labels are 1 where x0**2 + x1 > 0.5, else 0. Points are jittered AFTER
    labelling so that higher noise makes the classes overlap.

    Args:
        n_points: number of samples to draw.
        noise: jitter scale (multiplied by 0.1 before applying).

    Returns:
        (X, y) with X of shape (n_points, 2) and y of shape (n_points,).
    """
    features = np.random.randn(n_points, 2)
    labels = (features[:, 0] ** 2 + features[:, 1] > 0.5).astype(int)
    # Jitter applied after labelling, so labels no longer exactly match X.
    features += np.random.randn(*features.shape) * noise * 0.1
    return features, labels
33
+
34
+
35
+ # ------------------------------------------------
36
+ # 2D Random Forest Regression + Trees + Importance
37
+ # ------------------------------------------------
38
+
39
+ def rf_2d(n_points, noise, n_estimators, max_depth, show_trees):
40
+ X, y = generate_2d_regression(n_points, noise)
41
 
42
  model = RandomForestRegressor(
43
  n_estimators=n_estimators,
 
50
  mse = mean_squared_error(y, y_pred)
51
 
52
  fig, ax = plt.subplots(figsize=(5, 4))
53
+ ax.scatter(X, y, s=20, label="Data")
54
+ ax.plot(X, y_pred, linewidth=2, label="RF Avg")
55
+
56
+ # Individual trees
57
+ if show_trees:
58
+ for tree in model.estimators_[:5]:
59
+ ax.plot(X, tree.predict(X), linewidth=1, alpha=0.5)
60
+
61
  ax.set_title("2D Random Forest Regression")
62
  ax.legend()
63
 
64
+ # Feature importance
65
+ imp_fig, imp_ax = plt.subplots(figsize=(4, 3))
66
+ imp_ax.bar(["x"], model.feature_importances_)
67
+ imp_ax.set_title("Feature Importance")
68
+
69
+ return fig, imp_fig, f"MSE: {mse:.4f}"
70
 
71
 
72
  # ------------------------------------------------
73
+ # 3D Random Forest Regression + Rotation
74
  # ------------------------------------------------
 
 
 
 
 
 
75
 
76
+ def rf_3d(n_points, noise, n_estimators, max_depth, rotate):
77
+ X1, X2, X_flat, Z, Z_flat = generate_3d_regression(n_points, noise)
78
 
79
  model = RandomForestRegressor(
80
  n_estimators=n_estimators,
 
86
  Z_pred = model.predict(X_flat).reshape(X1.shape)
87
  mse = mean_squared_error(Z_flat, Z_pred.ravel())
88
 
89
+ frames = []
90
+
91
+ angles = range(0, 360, 10) if rotate else [45]
92
+
93
+ for angle in angles:
94
+ fig = plt.figure(figsize=(5, 4))
95
+ ax = fig.add_subplot(111, projection="3d")
96
+
97
+ idx = np.random.choice(len(Z_flat), min(400, len(Z_flat)), replace=False)
98
+ ax.scatter(X_flat[idx, 0], X_flat[idx, 1], Z_flat[idx], s=8, alpha=0.3)
 
 
99
 
100
+ ax.plot_surface(X1, X2, Z_pred, alpha=0.7)
101
+ ax.view_init(elev=25, azim=angle)
102
+ ax.set_title("3D Random Forest Surface")
103
 
104
+ frames.append(fig)
105
+
106
+ return frames, f"MSE: {mse:.4f}"
107
 
108
 
109
  # ------------------------------------------------
110
+ # Classification Version
111
  # ------------------------------------------------
112
+
113
def rf_classification(n_points, noise, n_estimators, max_depth):
    """Train a RandomForestClassifier on synthetic 2-D data and plot results.

    Args:
        n_points: number of synthetic samples.
        noise: jitter level passed to generate_classification.
        n_estimators: number of trees in the forest.
        max_depth: maximum depth of each tree.

    Returns:
        (fig, imp_fig, text): scatter of points coloured by *predicted* class,
        a feature-importance bar chart, and an accuracy string.
        NOTE(review): accuracy is measured on the training data itself, so it
        is optimistic — confirm whether a train/test split is wanted.
    """
    X, y = generate_classification(n_points, noise)

    model = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        random_state=42
    )
    model.fit(X, y)

    y_pred = model.predict(X)
    acc = accuracy_score(y, y_pred)

    fig, ax = plt.subplots(figsize=(5, 4))
    # Colour each point by the model's predicted class (fixed: the return
    # value of ax.scatter was previously bound to an unused local).
    ax.scatter(X[:, 0], X[:, 1], c=y_pred, s=20)
    ax.set_title("Random Forest Classification")

    # Feature importance
    imp_fig, imp_ax = plt.subplots(figsize=(4, 3))
    imp_ax.bar(["x1", "x2"], model.feature_importances_)
    imp_ax.set_title("Feature Importance")

    return fig, imp_fig, f"Accuracy: {acc:.4f}"
 
 
136
 
 
 
 
137
 
138
+ # ------------------------------------------------
139
+ # Gradio UI (Auto‑run on input change)
140
+ # ------------------------------------------------
141
+
142
+ with gr.Blocks() as demo:
143
+ gr.Markdown("# 🌲 Random Forest Visualizer (Full Interactive)")
144
+
145
+ mode = gr.Radio(
146
+ ["2D Regression", "3D Regression", "Classification"],
147
+ value="2D Regression",
148
+ label="Mode"
149
+ )
150
+
151
+ n_points = gr.Slider(50, 200, value=100, step=10, label="Data Points")
152
+ noise = gr.Slider(0.0, 5.0, value=1.0, label="Noise")
153
+ n_estimators = gr.Slider(10, 150, value=50, step=10, label="Trees")
154
+ max_depth = gr.Slider(2, 20, value=8, step=1, label="Max Depth")
155
+
156
+ show_trees = gr.Checkbox(label="Show Individual Trees (2D)", value=False)
157
+ rotate = gr.Checkbox(label="Rotate 3D", value=False)
158
 
159
  plot = gr.Plot()
160
+ plot2 = gr.Plot()
161
  metric = gr.Markdown()
162
 
163
def run(mode, n_points, noise, n_estimators, max_depth, show_trees, rotate):
    """Dispatch to the visualizer selected by *mode*.

    Always returns a 3-tuple (main plot, secondary plot or None, metric text)
    so it matches the [plot, plot2, metric] outputs wired to the UI.

    Bug fix: the previous `return (*rf_2d(...))` / `return (*rf_classification(...))`
    is a SyntaxError in Python — a starred expression cannot stand alone inside
    parentheses. The callees already return the 3-tuple, so return them directly.
    """
    if mode == "2D Regression":
        return rf_2d(n_points, noise, n_estimators, max_depth, show_trees)
    elif mode == "3D Regression":
        frames, text = rf_3d(n_points, noise, n_estimators, max_depth, rotate)
        # gr.Plot shows a single figure; only the first rendered frame is used.
        return frames[0], None, text
    else:
        return rf_classification(n_points, noise, n_estimators, max_depth)
171
+
172
+ inputs = [mode, n_points, noise, n_estimators, max_depth, show_trees, rotate]
173
+
174
+ for inp in inputs:
175
+ inp.change(run, inputs=inputs, outputs=[plot, plot2, metric])
176
+
177
+ demo.load(run, inputs=inputs, outputs=[plot, plot2, metric])
178
 
 
 
 
 
 
179
 
180
  demo.launch()