Joel Woodfield commited on
Commit
f78f440
·
1 Parent(s): ea020c0

Make the data plot into a 3D interactive plotly plot

Browse files
Files changed (1) hide show
  1. regularization.py +36 -20
regularization.py CHANGED
@@ -3,14 +3,15 @@ from pathlib import Path
3
  import pickle
4
 
5
  import gradio as gr
 
 
 
6
  import matplotlib.pyplot as plt
7
  import matplotlib.lines as mlines
8
- import inspect
9
  import numpy as np
10
  import pandas as pd
11
- import io
12
- from jinja2 import Template
13
  from PIL import Image
 
14
  import sklearn
15
  from sklearn.linear_model import LogisticRegression
16
  from sklearn.svm import LinearSVC
@@ -265,24 +266,39 @@ class Regularization:
265
 
266
  def plot_data(self):
267
  # make sure the data is the same as the one used in plot_regularization_contour
268
- X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
269
- fig, ax = plt.subplots(figsize=(8, 8))
270
-
271
- # plot data points
272
- ax.set_xlabel("X1")
273
- ax.set_ylabel("X2")
274
- sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
275
- fig.colorbar(sc, ax=ax, label="y")
276
-
277
- buf = io.BytesIO()
278
- fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
279
- plt.close(fig)
280
- buf.seek(0)
281
- img = Image.open(buf)
282
-
283
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  def plot_strength_vs_weight(self):
 
286
  X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
287
  alphas = np.concat([np.zeros(1), np.logspace(-2, 2, 100)])
288
  if self.loss_type == "l2":
@@ -357,7 +373,7 @@ class Regularization:
357
  with gr.Tab("Regularization contour"):
358
  self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
359
  with gr.Tab("Data"):
360
- self.data_image = gr.Image(value=self.plot_data(), container=True)
361
  with gr.Tab("Strength vs weight"):
362
  self.strength_vs_weight = gr.Image(value=self.plot_strength_vs_weight(), container=True)
363
 
 
3
  import pickle
4
 
5
  import gradio as gr
6
+ import inspect
7
+ import io
8
+ from jinja2 import Template
9
  import matplotlib.pyplot as plt
10
  import matplotlib.lines as mlines
 
11
  import numpy as np
12
  import pandas as pd
 
 
13
  from PIL import Image
14
+ import plotly.graph_objects as go
15
  import sklearn
16
  from sklearn.linear_model import LogisticRegression
17
  from sklearn.svm import LinearSVC
 
266
 
267
  def plot_data(self):
268
  # make sure the data is the same as the one used in plot_regularization_contour
269
+ _, _, coef = make_regression(n_samples=200, n_features=2, noise=15, random_state=0, coef=True)
270
+
271
+ x1 = np.linspace(-1, 1, 50)
272
+ x2 = np.linspace(-1, 1, 50)
273
+ mesh_x1, mesh_x2 = np.meshgrid(x1, x2)
274
+ X = np.stack((mesh_x1.ravel(), mesh_x2.ravel()), axis=-1)
275
+ y = X @ coef
276
+
277
+ z = y.reshape(mesh_x1.shape)
278
+
279
+ fig = go.Figure(data=go.Surface(
280
+ z=z,
281
+ x=mesh_x1,
282
+ y=mesh_x2,
283
+ colorscale='Viridis',
284
+ opacity=0.8
285
+ ))
286
+
287
+ fig.update_layout(
288
+ title="Data",
289
+ scene={
290
+ "xaxis": {"title": "X1", "nticks": 6},
291
+ "yaxis": {"title": "X2", "nticks": 6},
292
+ "zaxis": {"title": "Y", "nticks": 6},
293
+ "camera": {"eye": {"x": -1.5, "y": -1.5, "z": 1.2}},
294
+ },
295
+ width=800,
296
+ height=600,
297
+ )
298
+ return fig
299
 
300
  def plot_strength_vs_weight(self):
301
+ # make sure the data is the same as the one used in plot_regularization_contour
302
  X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
303
  alphas = np.concat([np.zeros(1), np.logspace(-2, 2, 100)])
304
  if self.loss_type == "l2":
 
373
  with gr.Tab("Regularization contour"):
374
  self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
375
  with gr.Tab("Data"):
376
+ self.data_3d_plot = gr.Plot(value=self.plot_data(), container=True)
377
  with gr.Tab("Strength vs weight"):
378
  self.strength_vs_weight = gr.Image(value=self.plot_strength_vs_weight(), container=True)
379