selva1909 commited on
Commit
e123993
·
verified ·
1 Parent(s): df5a253

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -49
app.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import matplotlib.pyplot as plt
4
  from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
5
  from sklearn.metrics import mean_squared_error, accuracy_score
 
6
 
7
  # ------------------------------------------------
8
  # Helpers
@@ -33,17 +34,13 @@ def generate_classification(n_points, noise):
33
 
34
 
35
  # ------------------------------------------------
36
- # 2D Random Forest Regression + Trees + Importance
37
  # ------------------------------------------------
38
 
39
  def rf_2d(n_points, noise, n_estimators, max_depth, show_trees):
40
  X, y = generate_2d_regression(n_points, noise)
41
 
42
- model = RandomForestRegressor(
43
- n_estimators=n_estimators,
44
- max_depth=max_depth,
45
- random_state=42
46
- )
47
  model.fit(X, y)
48
  y_pred = model.predict(X)
49
 
@@ -53,7 +50,6 @@ def rf_2d(n_points, noise, n_estimators, max_depth, show_trees):
53
  ax.scatter(X, y, s=20, label="Data")
54
  ax.plot(X, y_pred, linewidth=2, label="RF Avg")
55
 
56
- # Individual trees
57
  if show_trees:
58
  for tree in model.estimators_[:5]:
59
  ax.plot(X, tree.predict(X), linewidth=1, alpha=0.5)
@@ -61,7 +57,6 @@ def rf_2d(n_points, noise, n_estimators, max_depth, show_trees):
61
  ax.set_title("2D Random Forest Regression")
62
  ax.legend()
63
 
64
- # Feature importance
65
  imp_fig, imp_ax = plt.subplots(figsize=(4, 3))
66
  imp_ax.bar(["x"], model.feature_importances_)
67
  imp_ax.set_title("Feature Importance")
@@ -70,64 +65,67 @@ def rf_2d(n_points, noise, n_estimators, max_depth, show_trees):
70
 
71
 
72
  # ------------------------------------------------
73
- # 3D Random Forest Regression + Rotation
74
  # ------------------------------------------------
75
 
76
  def rf_3d(n_points, noise, n_estimators, max_depth, rotate):
77
  X1, X2, X_flat, Z, Z_flat = generate_3d_regression(n_points, noise)
78
 
79
- model = RandomForestRegressor(
80
- n_estimators=n_estimators,
81
- max_depth=max_depth,
82
- random_state=42
83
- )
84
  model.fit(X_flat, Z_flat)
85
 
86
  Z_pred = model.predict(X_flat).reshape(X1.shape)
87
  mse = mean_squared_error(Z_flat, Z_pred.ravel())
88
 
89
- frames = []
 
 
 
 
 
 
90
 
91
- angles = range(0, 360, 10) if rotate else [45]
92
 
93
- for angle in angles:
94
- fig = plt.figure(figsize=(5, 4))
95
- ax = fig.add_subplot(111, projection="3d")
 
 
 
 
 
 
 
96
 
97
- idx = np.random.choice(len(Z_flat), min(400, len(Z_flat)), replace=False)
98
- ax.scatter(X_flat[idx, 0], X_flat[idx, 1], Z_flat[idx], s=8, alpha=0.3)
 
 
99
 
100
- ax.plot_surface(X1, X2, Z_pred, alpha=0.7)
101
- ax.view_init(elev=25, azim=angle)
102
- ax.set_title("3D Random Forest Surface")
103
 
104
- frames.append(fig)
105
-
106
- return frames, f"MSE: {mse:.4f}"
107
 
108
 
109
  # ------------------------------------------------
110
- # Classification Version
111
  # ------------------------------------------------
112
 
113
  def rf_classification(n_points, noise, n_estimators, max_depth):
114
  X, y = generate_classification(n_points, noise)
115
 
116
- model = RandomForestClassifier(
117
- n_estimators=n_estimators,
118
- max_depth=max_depth,
119
- random_state=42
120
- )
121
  model.fit(X, y)
122
 
123
  y_pred = model.predict(X)
124
  acc = accuracy_score(y, y_pred)
125
 
126
  fig, ax = plt.subplots(figsize=(5, 4))
127
- scatter = ax.scatter(X[:, 0], X[:, 1], c=y_pred, s=20)
128
  ax.set_title("Random Forest Classification")
129
 
130
- # Feature importance
131
  imp_fig, imp_ax = plt.subplots(figsize=(4, 3))
132
  imp_ax.bar(["x1", "x2"], model.feature_importances_)
133
  imp_ax.set_title("Feature Importance")
@@ -136,17 +134,13 @@ def rf_classification(n_points, noise, n_estimators, max_depth):
136
 
137
 
138
  # ------------------------------------------------
139
- # Gradio UI (Auto‑run on input change)
140
  # ------------------------------------------------
141
 
142
  with gr.Blocks() as demo:
143
- gr.Markdown("# 🌲 Random Forest Visualizer (Full Interactive)")
144
 
145
- mode = gr.Radio(
146
- ["2D Regression", "3D Regression", "Classification"],
147
- value="2D Regression",
148
- label="Mode"
149
- )
150
 
151
  n_points = gr.Slider(50, 200, value=100, step=10, label="Data Points")
152
  noise = gr.Slider(0.0, 5.0, value=1.0, label="Noise")
@@ -154,29 +148,32 @@ with gr.Blocks() as demo:
154
  max_depth = gr.Slider(2, 20, value=8, step=1, label="Max Depth")
155
 
156
  show_trees = gr.Checkbox(label="Show Individual Trees (2D)", value=False)
157
- rotate = gr.Checkbox(label="Rotate 3D", value=False)
158
 
159
  plot = gr.Plot()
160
  plot2 = gr.Plot()
 
161
  metric = gr.Markdown()
162
 
163
  def run(mode, n_points, noise, n_estimators, max_depth, show_trees, rotate):
164
  if mode == "2D Regression":
165
  fig1, fig2, text = rf_2d(n_points, noise, n_estimators, max_depth, show_trees)
166
- return fig1, fig2, text
 
167
  elif mode == "3D Regression":
168
- frames, text = rf_3d(n_points, noise, n_estimators, max_depth, rotate)
169
- return frames[0], None, text
 
170
  else:
171
  fig1, fig2, text = rf_classification(n_points, noise, n_estimators, max_depth)
172
- return fig1, fig2, text
173
 
174
  inputs = [mode, n_points, noise, n_estimators, max_depth, show_trees, rotate]
175
 
176
  for inp in inputs:
177
- inp.change(run, inputs=inputs, outputs=[plot, plot2, metric])
178
 
179
- demo.load(run, inputs=inputs, outputs=[plot, plot2, metric])
180
 
181
 
182
  demo.launch()
 
3
  import matplotlib.pyplot as plt
4
  from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
5
  from sklearn.metrics import mean_squared_error, accuracy_score
6
+ import imageio, os, tempfile
7
 
8
  # ------------------------------------------------
9
  # Helpers
 
34
 
35
 
36
  # ------------------------------------------------
37
+ # 2D Random Forest Regression
38
  # ------------------------------------------------
39
 
40
  def rf_2d(n_points, noise, n_estimators, max_depth, show_trees):
41
  X, y = generate_2d_regression(n_points, noise)
42
 
43
+ model = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
 
 
 
 
44
  model.fit(X, y)
45
  y_pred = model.predict(X)
46
 
 
50
  ax.scatter(X, y, s=20, label="Data")
51
  ax.plot(X, y_pred, linewidth=2, label="RF Avg")
52
 
 
53
  if show_trees:
54
  for tree in model.estimators_[:5]:
55
  ax.plot(X, tree.predict(X), linewidth=1, alpha=0.5)
 
57
  ax.set_title("2D Random Forest Regression")
58
  ax.legend()
59
 
 
60
  imp_fig, imp_ax = plt.subplots(figsize=(4, 3))
61
  imp_ax.bar(["x"], model.feature_importances_)
62
  imp_ax.set_title("Feature Importance")
 
65
 
66
 
67
  # ------------------------------------------------
68
+ # 3D Random Forest Regression with GIF rotation
69
  # ------------------------------------------------
70
 
71
  def rf_3d(n_points, noise, n_estimators, max_depth, rotate):
72
  X1, X2, X_flat, Z, Z_flat = generate_3d_regression(n_points, noise)
73
 
74
+ model = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
 
 
 
 
75
  model.fit(X_flat, Z_flat)
76
 
77
  Z_pred = model.predict(X_flat).reshape(X1.shape)
78
  mse = mean_squared_error(Z_flat, Z_pred.ravel())
79
 
80
+ # Static plot
81
+ fig = plt.figure(figsize=(5, 4))
82
+ ax = fig.add_subplot(111, projection="3d")
83
+ idx = np.random.choice(len(Z_flat), min(400, len(Z_flat)), replace=False)
84
+ ax.scatter(X_flat[idx, 0], X_flat[idx, 1], Z_flat[idx], s=8, alpha=0.3)
85
+ ax.plot_surface(X1, X2, Z_pred, alpha=0.7)
86
+ ax.set_title("3D Random Forest Surface")
87
 
88
+ gif_path = None
89
 
90
+ # Create rotating GIF if enabled
91
+ if rotate:
92
+ tmpdir = tempfile.mkdtemp()
93
+ images = []
94
+ for angle in range(0, 360, 10):
95
+ fig_rot = plt.figure(figsize=(5, 4))
96
+ ax_rot = fig_rot.add_subplot(111, projection="3d")
97
+ ax_rot.scatter(X_flat[idx, 0], X_flat[idx, 1], Z_flat[idx], s=8, alpha=0.3)
98
+ ax_rot.plot_surface(X1, X2, Z_pred, alpha=0.7)
99
+ ax_rot.view_init(elev=25, azim=angle)
100
 
101
+ frame_path = os.path.join(tmpdir, f"frame_{angle}.png")
102
+ fig_rot.savefig(frame_path)
103
+ plt.close(fig_rot)
104
+ images.append(imageio.imread(frame_path))
105
 
106
+ gif_path = os.path.join(tmpdir, "rotation.gif")
107
+ imageio.mimsave(gif_path, images, duration=0.08)
 
108
 
109
+ return fig, gif_path, f"MSE: {mse:.4f}"
 
 
110
 
111
 
112
  # ------------------------------------------------
113
+ # Classification
114
  # ------------------------------------------------
115
 
116
  def rf_classification(n_points, noise, n_estimators, max_depth):
117
  X, y = generate_classification(n_points, noise)
118
 
119
+ model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
 
 
 
 
120
  model.fit(X, y)
121
 
122
  y_pred = model.predict(X)
123
  acc = accuracy_score(y, y_pred)
124
 
125
  fig, ax = plt.subplots(figsize=(5, 4))
126
+ ax.scatter(X[:, 0], X[:, 1], c=y_pred, s=20)
127
  ax.set_title("Random Forest Classification")
128
 
 
129
  imp_fig, imp_ax = plt.subplots(figsize=(4, 3))
130
  imp_ax.bar(["x1", "x2"], model.feature_importances_)
131
  imp_ax.set_title("Feature Importance")
 
134
 
135
 
136
  # ------------------------------------------------
137
+ # Gradio UI
138
  # ------------------------------------------------
139
 
140
  with gr.Blocks() as demo:
141
+ gr.Markdown("# 🌲 Random Forest Visualizer (Fixed + Rotation)")
142
 
143
+ mode = gr.Radio(["2D Regression", "3D Regression", "Classification"], value="2D Regression", label="Mode")
 
 
 
 
144
 
145
  n_points = gr.Slider(50, 200, value=100, step=10, label="Data Points")
146
  noise = gr.Slider(0.0, 5.0, value=1.0, label="Noise")
 
148
  max_depth = gr.Slider(2, 20, value=8, step=1, label="Max Depth")
149
 
150
  show_trees = gr.Checkbox(label="Show Individual Trees (2D)", value=False)
151
+ rotate = gr.Checkbox(label="Rotate 3D (GIF)", value=False)
152
 
153
  plot = gr.Plot()
154
  plot2 = gr.Plot()
155
+ gif = gr.Image(label="3D Rotation")
156
  metric = gr.Markdown()
157
 
158
  def run(mode, n_points, noise, n_estimators, max_depth, show_trees, rotate):
159
  if mode == "2D Regression":
160
  fig1, fig2, text = rf_2d(n_points, noise, n_estimators, max_depth, show_trees)
161
+ return fig1, fig2, None, text
162
+
163
  elif mode == "3D Regression":
164
+ fig, gif_path, text = rf_3d(n_points, noise, n_estimators, max_depth, rotate)
165
+ return fig, None, gif_path, text
166
+
167
  else:
168
  fig1, fig2, text = rf_classification(n_points, noise, n_estimators, max_depth)
169
+ return fig1, fig2, None, text
170
 
171
  inputs = [mode, n_points, noise, n_estimators, max_depth, show_trees, rotate]
172
 
173
  for inp in inputs:
174
+ inp.change(run, inputs=inputs, outputs=[plot, plot2, gif, metric])
175
 
176
+ demo.load(run, inputs=inputs, outputs=[plot, plot2, gif, metric])
177
 
178
 
179
  demo.launch()