Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier | |
| from sklearn.metrics import mean_squared_error, accuracy_score | |
| # ------------------------------------------------ | |
| # DATA GENERATORS | |
| # ------------------------------------------------ | |
def generate_2d_regression(n_points, noise):
    """Sample a noisy straight line ``y = 2.5 * x + 5`` on ``[0, 10]``.

    Args:
        n_points: How many evenly spaced x positions to draw; coerced to int
            because slider values may arrive as floats.
        noise: Scale (standard deviation) of the Gaussian noise added to y.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(n, 1)`` (the 2-D feature matrix
        scikit-learn expects) and ``y`` is a 1-D array of length ``n``.
    """
    count = int(n_points)
    xs = np.linspace(0, 10, count)
    jitter = np.random.randn(count) * noise
    targets = 2.5 * xs + 5 + jitter
    return xs[:, np.newaxis], targets
def generate_3d_regression(n_points, noise):
    """Sample a noisy plane ``z = 3*x1 + 2*x2 + 10`` on a square grid.

    The grid is ``n_points`` by ``n_points`` over ``[0, 10]^2``, so the
    returned training set holds ``n_points ** 2`` samples.

    Args:
        n_points: Grid resolution per axis; coerced to int.
        noise: Scale (standard deviation) of the Gaussian noise added to z.

    Returns:
        ``(X1, X2, X_flat, Z_flat)``: the two ``(n, n)`` meshgrid coordinate
        arrays (used later for plotting a surface), the ``(n*n, 2)`` feature
        matrix, and the matching 1-D target vector.
    """
    side = int(n_points)
    # Both axes share the same range and resolution, so one linspace suffices.
    axis = np.linspace(0, 10, side)
    grid_a, grid_b = np.meshgrid(axis, axis)
    surface = 3 * grid_a + 2 * grid_b + 10 + np.random.randn(*grid_a.shape) * noise
    features = np.stack([grid_a.ravel(), grid_b.ravel()], axis=1)
    return grid_a, grid_b, features, surface.ravel()
def generate_classification(n_points, noise):
    """Sample 2-D points labelled by whether ``x0**2 + x1 > 0.5``.

    Labels are assigned from the clean standard-normal coordinates. Gaussian
    jitter scaled by ``0.1 * noise`` is added afterwards, so a larger
    ``noise`` pushes points across the true boundary without relabelling them.

    Args:
        n_points: Number of samples; coerced to int.
        noise: Jitter strength (multiplied by 0.1).

    Returns:
        ``(X, y)``: an ``(n, 2)`` float feature matrix and an ``(n,)`` array
        of 0/1 integer labels.
    """
    count = int(n_points)
    clean = np.random.randn(count, 2)
    above = clean[:, 0] ** 2 + clean[:, 1] > 0.5
    labels = above.astype(int)
    # Perturb only after labelling: this is what lets `noise` create overlap.
    jittered = clean + np.random.randn(*clean.shape) * noise * 0.1
    return jittered, labels
| # ------------------------------------------------ | |
| # 2D RANDOM FOREST REGRESSION | |
| # ------------------------------------------------ | |
def rf_2d_view(n_points, noise, n_estimators, max_depth):
    """Fit a random-forest regressor to noisy 1-D linear data and plot the fit.

    Args:
        n_points: Number of samples, forwarded to ``generate_2d_regression``.
        noise: Noise scale, forwarded to ``generate_2d_regression``.
        n_estimators: Number of trees; coerced to int because Gradio sliders
            may deliver floats.
        max_depth: Maximum tree depth; coerced to int for the same reason.

    Returns:
        ``(figure, text)``: a Plotly figure with the raw points and the
        forest's prediction line, and an ``"MSE: …"`` string. The MSE is
        measured on the same points the forest was trained on.
    """
    trees, depth = int(n_estimators), int(max_depth)
    X, y = generate_2d_regression(n_points, noise)

    forest = RandomForestRegressor(n_estimators=trees, max_depth=depth, random_state=42)
    fitted = forest.fit(X, y).predict(X)
    error = mean_squared_error(y, fitted)

    xs = X.ravel()
    fig = go.Figure(data=[
        go.Scatter(x=xs, y=y, mode="markers", name="Data"),
        go.Scatter(x=xs, y=fitted, mode="lines", name="RF Prediction"),
    ])
    fig.update_layout(
        title="2D Random Forest Regression",
        paper_bgcolor="#0b1e3d",
        plot_bgcolor="#0b1e3d",
        font=dict(color="white"),
    )
    return fig, f"MSE: {error:.4f}"
| # ------------------------------------------------ | |
| # 3D RANDOM FOREST REGRESSION (ROTATING) | |
| # ------------------------------------------------ | |
def rf_3d_view(n_points, noise, n_estimators, max_depth):
    """Fit a random-forest regressor to a noisy plane and plot it as a surface.

    Args:
        n_points: Grid resolution per axis, forwarded to
            ``generate_3d_regression`` (the forest sees ``n_points ** 2``
            samples).
        noise: Noise scale, forwarded to ``generate_3d_regression``.
        n_estimators: Number of trees; coerced to int.
        max_depth: Maximum tree depth; coerced to int.

    Returns:
        ``(figure, text)``: a Plotly 3-D surface of the forest's predictions,
        with a "Rotate 3D" button that animates the camera around the scene,
        and an ``"MSE: …"`` string computed on the training grid.
    """
    trees, depth = int(n_estimators), int(max_depth)
    X1, X2, X_flat, Z_flat = generate_3d_regression(n_points, noise)

    forest = RandomForestRegressor(n_estimators=trees, max_depth=depth, random_state=42)
    forest.fit(X_flat, Z_flat)
    flat_pred = forest.predict(X_flat)
    error = mean_squared_error(Z_flat, flat_pred)
    surface_z = flat_pred.reshape(X1.shape)

    fig = go.Figure(
        data=[go.Surface(x=X1, y=X2, z=surface_z, colorscale="Blues", opacity=0.9)]
    )

    def _camera_frame(angle):
        # Camera on a circle of radius 2 at fixed height 1.2, `angle` degrees round.
        rad = np.radians(angle)
        return go.Frame(layout=dict(scene_camera=dict(
            eye=dict(x=np.cos(rad) * 2, y=np.sin(rad) * 2, z=1.2)
        )))

    # One frame every 10 degrees; the button below plays them in order.
    fig.frames = [_camera_frame(angle) for angle in range(0, 360, 10)]

    rotate_button = dict(
        label="Rotate 3D",
        method="animate",
        args=[None, {"frame": {"duration": 60, "redraw": True}, "fromcurrent": True}],
    )
    fig.update_layout(
        title="3D Random Forest Regression",
        scene=dict(bgcolor="#0b1e3d"),
        paper_bgcolor="#0b1e3d",
        font=dict(color="white"),
        updatemenus=[dict(type="buttons", showactive=False, buttons=[rotate_button])],
    )
    return fig, f"MSE: {error:.4f}"
| # ------------------------------------------------ | |
| # CLASSIFICATION VIEW | |
| # ------------------------------------------------ | |
def classification_view(n_points, noise, n_estimators, max_depth):
    """Fit a random-forest classifier on 2-D data and plot its decision regions.

    Args:
        n_points: Number of samples, forwarded to ``generate_classification``.
        noise: Jitter scale, forwarded to ``generate_classification``.
        n_estimators: Number of trees; coerced to int because Gradio sliders
            may deliver floats.
        max_depth: Maximum tree depth; coerced to int for the same reason.

    Returns:
        ``(figure, text)``: a Plotly figure and an ``"Accuracy: …"`` string.
        The figure shows the predicted class over a 100x100 grid, padded by
        1 unit beyond the data on each axis. The samples are overlaid and
        coloured by their true class.
    """
    n_estimators = int(n_estimators)
    max_depth = int(max_depth)
    X, y = generate_classification(n_points, noise)
    rf = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, random_state=42
    )
    rf.fit(X, y)

    # Evaluate the forest on a dense grid to draw the decision regions.
    xx, yy = np.meshgrid(
        np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100),
        np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100),
    )
    Z = rf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    # NOTE: scored on the training points themselves, so this is an
    # optimistic (training) accuracy rather than a held-out estimate.
    acc = accuracy_score(y, rf.predict(X))

    # Use one explicit, high-contrast colour per class. A sequential scale
    # such as "Blues" renders one end in a navy almost identical to the
    # #0b1e3d background, and would make points blend into the regions.
    class_colors = {0: "#4fc3f7", 1: "#ffb74d"}

    fig = go.Figure()
    fig.add_contour(
        x=xx[0],
        y=yy[:, 0],
        z=Z,
        showscale=False,
        opacity=0.35,
        colorscale=[[0.0, class_colors[0]], [1.0, class_colors[1]]],
        # Pin the colour range to the label range so that a single-class
        # prediction still shows that class's colour instead of relying on
        # autoscaling.
        zmin=0,
        zmax=1,
    )
    for label, color in class_colors.items():
        members = y == label
        fig.add_scatter(
            x=X[members, 0],
            y=X[members, 1],
            mode="markers",
            # White outline keeps points readable on top of the filled regions.
            marker=dict(color=color, line=dict(color="white", width=1)),
            name=f"Class {label}",
        )
    fig.update_layout(
        title="Random Forest Classification Boundary",
        paper_bgcolor="#0b1e3d",
        plot_bgcolor="#0b1e3d",
        font=dict(color="white"),
    )
    return fig, f"Accuracy: {acc:.4f}"
| # ------------------------------------------------ | |
| # GRADIO UI | |
| # ------------------------------------------------ | |
with gr.Blocks() as demo:
    gr.Markdown("# 🌲 Random Forest Learning Dashboard (2D • 3D • Classification)")

    # Shared hyper-parameter controls; all three tabs are re-rendered from them.
    n_points = gr.Slider(20, 100, 40, step=1, label="Number of Points")
    noise = gr.Slider(0.0, 3.0, 1.0, label="Noise Level")
    n_estimators = gr.Slider(10, 100, 50, step=1, label="Number of Trees")
    max_depth = gr.Slider(2, 15, 8, step=1, label="Tree Depth")

    with gr.Tab("📈 2D Regression"):
        plot2d = gr.Plot()
        mse2d = gr.Markdown()
    with gr.Tab("🌐 3D Regression"):
        plot3d = gr.Plot()
        mse3d = gr.Markdown()
    with gr.Tab("🧩 Classification"):
        plot_cls = gr.Plot()
        acc_cls = gr.Markdown()

    def run_all(n_points, noise, n_estimators, max_depth):
        """Rebuild every view from the current slider values.

        Returns six values in the order (2D figure, 2D MSE text, 3D figure,
        3D MSE text, classification figure, accuracy text), matching the
        output components wired up below.
        """
        settings = (n_points, noise, n_estimators, max_depth)
        results = []
        for view in (rf_2d_view, rf_3d_view, classification_view):
            results.extend(view(*settings))
        return tuple(results)

    controls = [n_points, noise, n_estimators, max_depth]
    outputs = [plot2d, mse2d, plot3d, mse3d, plot_cls, acc_cls]
    # Any slider change retrains all three models; the initial render happens on load.
    for control in controls:
        control.change(run_all, controls, outputs)
    demo.load(run_all, controls, outputs)

# NOTE(review): `theme` is passed to launch(), which matches the Gradio 6 API;
# on Gradio 4/5 it belongs on gr.Blocks(...) instead — confirm installed version.
demo.launch(theme=gr.themes.Soft(primary_hue="blue"))