Spaces:
Sleeping
Sleeping
Joel Woodfield commited on
Commit ·
f78f440
1
Parent(s): ea020c0
Make the data plot into a 3D interactive plotly plot
Browse files- regularization.py +36 -20
regularization.py
CHANGED
|
@@ -3,14 +3,15 @@ from pathlib import Path
|
|
| 3 |
import pickle
|
| 4 |
|
| 5 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import matplotlib.lines as mlines
|
| 8 |
-
import inspect
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
-
import io
|
| 12 |
-
from jinja2 import Template
|
| 13 |
from PIL import Image
|
|
|
|
| 14 |
import sklearn
|
| 15 |
from sklearn.linear_model import LogisticRegression
|
| 16 |
from sklearn.svm import LinearSVC
|
|
@@ -265,24 +266,39 @@ class Regularization:
|
|
| 265 |
|
| 266 |
def plot_data(self):
|
| 267 |
# make sure the data is the same as the one used in plot_regularization_contour
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
fig
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
def plot_strength_vs_weight(self):
|
|
|
|
| 286 |
X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
|
| 287 |
alphas = np.concat([np.zeros(1), np.logspace(-2, 2, 100)])
|
| 288 |
if self.loss_type == "l2":
|
|
@@ -357,7 +373,7 @@ class Regularization:
|
|
| 357 |
with gr.Tab("Regularization contour"):
|
| 358 |
self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
|
| 359 |
with gr.Tab("Data"):
|
| 360 |
-
self.
|
| 361 |
with gr.Tab("Strength vs weight"):
|
| 362 |
self.strength_vs_weight = gr.Image(value=self.plot_strength_vs_weight(), container=True)
|
| 363 |
|
|
|
|
| 3 |
import pickle
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
+
import inspect
|
| 7 |
+
import io
|
| 8 |
+
from jinja2 import Template
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import matplotlib.lines as mlines
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import pandas as pd
|
|
|
|
|
|
|
| 13 |
from PIL import Image
|
| 14 |
+
import plotly.graph_objects as go
|
| 15 |
import sklearn
|
| 16 |
from sklearn.linear_model import LogisticRegression
|
| 17 |
from sklearn.svm import LinearSVC
|
|
|
|
| 266 |
|
| 267 |
def plot_data(self):
|
| 268 |
# make sure the data is the same as the one used in plot_regularization_contour
|
| 269 |
+
_, _, coef = make_regression(n_samples=200, n_features=2, noise=15, random_state=0, coef=True)
|
| 270 |
+
|
| 271 |
+
x1 = np.linspace(-1, 1, 50)
|
| 272 |
+
x2 = np.linspace(-1, 1, 50)
|
| 273 |
+
mesh_x1, mesh_x2 = np.meshgrid(x1, x2)
|
| 274 |
+
X = np.stack((mesh_x1.ravel(), mesh_x2.ravel()), axis=-1)
|
| 275 |
+
y = X @ coef
|
| 276 |
+
|
| 277 |
+
z = y.reshape(mesh_x1.shape)
|
| 278 |
+
|
| 279 |
+
fig = go.Figure(data=go.Surface(
|
| 280 |
+
z=z,
|
| 281 |
+
x=mesh_x1,
|
| 282 |
+
y=mesh_x2,
|
| 283 |
+
colorscale='Viridis',
|
| 284 |
+
opacity=0.8
|
| 285 |
+
))
|
| 286 |
+
|
| 287 |
+
fig.update_layout(
|
| 288 |
+
title="Data",
|
| 289 |
+
scene={
|
| 290 |
+
"xaxis": {"title": "X1", "nticks": 6},
|
| 291 |
+
"yaxis": {"title": "X2", "nticks": 6},
|
| 292 |
+
"zaxis": {"title": "Y", "nticks": 6},
|
| 293 |
+
"camera": {"eye": {"x": -1.5, "y": -1.5, "z": 1.2}},
|
| 294 |
+
},
|
| 295 |
+
width=800,
|
| 296 |
+
height=600,
|
| 297 |
+
)
|
| 298 |
+
return fig
|
| 299 |
|
| 300 |
def plot_strength_vs_weight(self):
|
| 301 |
+
# make sure the data is the same as the one used in plot_regularization_contour
|
| 302 |
X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
|
| 303 |
alphas = np.concat([np.zeros(1), np.logspace(-2, 2, 100)])
|
| 304 |
if self.loss_type == "l2":
|
|
|
|
| 373 |
with gr.Tab("Regularization contour"):
|
| 374 |
self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
|
| 375 |
with gr.Tab("Data"):
|
| 376 |
+
self.data_3d_plot = gr.Plot(value=self.plot_data(), container=True)
|
| 377 |
with gr.Tab("Strength vs weight"):
|
| 378 |
self.strength_vs_weight = gr.Image(value=self.plot_strength_vs_weight(), container=True)
|
| 379 |
|