regularization / old /dataset.py
joel-woodfield's picture
Refactor code to separate frontend and backend
770d448
import gradio as gr
import numpy as np
import numexpr
import pandas as pd
import time
# Named constants made available to expressions evaluated with numexpr;
# both spellings of pi map to the same value.
NUMEXPR_CONSTANTS = dict(
    pi=np.pi,
    PI=np.pi,
    e=np.e,
)
def get_function(function, x1lim, x2lim, nsample=100):
    """Evaluate an expression on a regular ``nsample`` x ``nsample`` grid.

    Args:
        function: Expression string in terms of ``x1`` and ``x2`` (the names
            in ``NUMEXPR_CONSTANTS`` are also available), evaluated with
            ``numexpr.evaluate``.
        x1lim: ``(lower, upper)`` bounds of the grid along x1.
        x2lim: ``(lower, upper)`` bounds of the grid along x2.
        nsample: Number of grid points along each axis.

    Returns:
        ``(mesh_x1, mesh_x2, y)`` where all three arrays share the meshgrid
        shape and ``y`` holds the expression value at each grid point.
    """
    x1 = np.linspace(x1lim[0], x1lim[1], nsample)
    x2 = np.linspace(x2lim[0], x2lim[1], nsample)
    mesh_x1, mesh_x2 = np.meshgrid(x1, x2)

    # NOTE(review): `function` is evaluated as an expression; if it comes from
    # untrusted users, confirm the installed numexpr version sanitizes input.
    y = np.asarray(
        numexpr.evaluate(
            function,
            local_dict={
                'x1': mesh_x1.ravel(),
                'x2': mesh_x2.ravel(),
                **NUMEXPR_CONSTANTS,
            },
        )
    )

    if y.size == mesh_x1.size:
        y = y.reshape(mesh_x1.shape)
    else:
        # An expression that uses neither x1 nor x2 (e.g. "pi") evaluates to
        # a single value; spread it over the grid instead of failing in
        # reshape. The copy keeps the result writable.
        y = np.broadcast_to(y, mesh_x1.shape).copy()

    return mesh_x1, mesh_x2, y
def get_data_points(function, x1lim, x2lim, nsample=10, sigma=0., random_x=False, seed=0):
    """Sample noisy observations of an expression over a rectangular domain.

    Args:
        function: Expression string in terms of ``x1`` and ``x2`` (the names
            in ``NUMEXPR_CONSTANTS`` are also available), evaluated with
            ``numexpr.evaluate``.
        x1lim: ``(lower, upper)`` bounds for x1.
        x2lim: ``(lower, upper)`` bounds for x2.
        nsample: Number of points to return. Non-integer values are truncated
            with ``int()``.
        sigma: Standard deviation of the Gaussian noise added to ``y``.
        random_x: If true, draw the inputs uniformly at random inside the
            bounds; otherwise take the first ``nsample`` points of a
            ``ceil(sqrt(nsample))``-per-axis grid.
        seed: Seed for the random generator (inputs and noise).

    Returns:
        ``(X, y)`` with ``X`` of shape ``(nsample, 2)`` (columns x1, x2) and
        ``y`` of shape ``(nsample,)``.
    """
    # UI number inputs may deliver floats; slicing and `size=` need an int.
    nsample = int(nsample)

    # One generator for both inputs and noise: in random mode the noise is
    # drawn after the inputs, in grid mode it is the first draw.
    rng = np.random.default_rng(seed)

    if random_x:
        x1 = rng.uniform(x1lim[0], x1lim[1], size=nsample)
        x2 = rng.uniform(x2lim[0], x2lim[1], size=nsample)
    else:
        size = int(np.ceil(np.sqrt(nsample)))
        grid_x1, grid_x2 = np.meshgrid(
            np.linspace(x1lim[0], x1lim[1], size),
            np.linspace(x2lim[0], x2lim[1], size),
        )
        # The grid may hold more than nsample points; keep the first ones.
        x1 = grid_x1.ravel()[:nsample]
        x2 = grid_x2.ravel()[:nsample]

    noise = sigma * rng.standard_normal(nsample)

    # NOTE(review): `function` is evaluated as an expression; if it comes from
    # untrusted users, confirm the installed numexpr version sanitizes input.
    y = np.asarray(
        numexpr.evaluate(
            function,
            local_dict={'x1': x1, 'x2': x2, **NUMEXPR_CONSTANTS},
        )
    )
    # Out-of-place add: a constant expression yields a single value that must
    # be broadcast to every sample, and a non-float result (e.g. a comparison)
    # cannot absorb float noise in place.
    y = np.broadcast_to(y, x1.shape) + noise

    X = np.stack([x1, x2], axis=1)
    return X, y
class Dataset:
def __init__(
self,
mode: str = "generate",
function: str = "25 * x1 + 30 * x2",
x1lim: tuple[float, float] = (-1, 1),
x2lim: tuple[float, float] = (-1, 1),
nsample: int = 100,
sigma: float = 0.1,
random_x: bool = False,
seed: int = 0,
csv_path: str | None = None,
):
self.mode = mode
self.function = function
self.x1lim = x1lim
self.x2lim = x2lim
self.nsample = nsample
self.sigma = sigma
self.random_x = random_x
self.seed = seed
self.csv_path = csv_path
self.X, self.y = self._get_data()
def get_function(self, nsample: int = 100):
return get_function(
function=self.function,
x1lim=self.x1lim,
x2lim=self.x2lim,
nsample=nsample,
)
def _get_data(self):
if self.mode == "generate" or self.csv_path is None:
return get_data_points(
function=self.function,
x1lim=self.x1lim,
x2lim=self.x2lim,
nsample=self.nsample,
sigma=self.sigma,
random_x=self.random_x,
seed=self.seed,
)
elif self.mode == "csv":
if self.csv_path is None:
raise RuntimeError("Something is wrong")
df = pd.read_csv(self.csv_path)
if df.shape[1] != 3:
raise ValueError("CSV file must have exactly three columns")
x = df.iloc[:, :-1].values
y = df.iloc[:, -1].values
return x, y
else:
raise ValueError(f"Unknown dataset mode: {self.mode}")
def update(self, **kwargs):
return Dataset(
mode=kwargs.get("mode", self.mode),
function=kwargs.get("function", self.function),
x1lim=kwargs.get("x1lim", self.x1lim),
x2lim=kwargs.get("x2lim", self.x2lim),
nsample=kwargs.get("nsample", self.nsample),
sigma=kwargs.get("sigma", self.sigma),
random_x=kwargs.get("random_x", self.random_x),
seed=kwargs.get("seed", self.seed),
csv_path=kwargs.get("csv_path", self.csv_path),
)
def _safe_hash(self, val: int | float) -> int | float | tuple[int, str]:
# special handling for -1 (same hash number as -2)
if val == -1:
return (-1, "special")
return val
def __hash__(self):
return hash(
(
self.mode,
self.function,
self._safe_hash(self.x1lim[0]),
self._safe_hash(self.x1lim[1]),
self._safe_hash(self.x2lim[0]),
self._safe_hash(self.x2lim[1]),
self.nsample,
self.sigma,
self.random_x,
self.seed,
self.csv_path,
)
)
class DatasetView:
    """Gradio controls for editing the dataset held in a ``gr.State``.

    Every handler receives the current state value, an object exposing
    ``update(**fields)`` that returns a new state, and returns the state to
    store. Recoverable input problems are reported with ``gr.Info`` and leave
    the state unchanged.
    """

    @staticmethod
    def _parse_limits(text: str, name: str) -> tuple[float, float]:
        """Parse ``"low, high"`` into a pair of floats.

        Raises:
            ValueError: If a part is not a number or there are not exactly
                two parts; ``name`` is used in the message.
        """
        values = tuple(map(float, text.split(",")))
        if len(values) != 2:
            raise ValueError(f"{name} must have exactly two values")
        return values

    def _update_limits(self, field: str, text: str, state):
        """Apply parsed limits to ``field``; on any failure notify and keep state."""
        try:
            return state.update(**{field: self._parse_limits(text, field)})
        except Exception as e:
            # UI boundary: show the problem instead of failing the event.
            gr.Info(f"⚠️ {e}")
            return state

    def update_mode(self, mode: str, state: gr.State):
        """Switch the dataset mode and toggle the matching controls.

        Returns:
            The new state followed by visibility updates for, in order:
            function, x1lim, x2lim, sigma, nsample, regenerate, csv upload.

        Raises:
            ValueError: If ``mode`` is neither ``"generate"`` nor ``"csv"``.
        """
        if mode not in ("generate", "csv"):
            raise ValueError(f"Unknown mode: {mode}")

        state = state.update(mode=mode)
        generating = mode == "generate"
        return (
            state,
            gr.update(visible=generating),  # function
            gr.update(visible=generating),  # x1lim
            gr.update(visible=generating),  # x2lim
            gr.update(visible=generating),  # sigma
            gr.update(visible=generating),  # nsample
            gr.update(visible=generating),  # regenerate
            gr.update(visible=not generating),  # csv upload
        )

    def update_x1lim(self, x1lim_str: str, state: gr.State):
        """Set the x1 range from a ``"low, high"`` string."""
        return self._update_limits("x1lim", x1lim_str, state)

    def update_x2lim(self, x2lim_str: str, state: gr.State):
        """Set the x2 range from a ``"low, high"`` string."""
        return self._update_limits("x2lim", x2lim_str, state)

    def update_x_selection_method(self, method: str, state: gr.State):
        """Enable random inputs when ``method`` is ``"Uniformly sampled"``."""
        random_x = method == "Uniformly sampled"
        print("Updating random_x to", random_x)
        state = state.update(random_x=random_x)
        return state

    def upload_csv(self, file, state):
        """Switch to csv mode using the uploaded file's path (``file.name``)."""
        try:
            state = state.update(
                mode="csv",
                csv_path=file.name,
            )
        except Exception as e:
            # UI boundary: show the problem instead of failing the event.
            gr.Info(f"⚠️ {e}")
        return state

    def regenerate_data(self, state: gr.State):
        """Re-seed the dataset from the current time in milliseconds."""
        seed = int(time.time() * 1000) % (2 ** 32)
        state = state.update(seed=seed)
        return state

    def update_all(
        self,
        function: str,
        x1lim_str: str,
        x2lim_str: str,
        sigma: float,
        nsample: int,
        state: gr.State,
    ):
        """Apply every generate-mode control to the state in one update.

        A range that fails to parse is reported with ``gr.Info`` and keeps
        its previous value; all other fields are still applied.
        """
        changes = {"function": function, "sigma": sigma, "nsample": nsample}
        for field, text in (("x1lim", x1lim_str), ("x2lim", x2lim_str)):
            try:
                changes[field] = self._parse_limits(text, field)
            except Exception as e:
                gr.Info(f"⚠️ {e}")

        # One combined update builds the new state once instead of once per
        # field.
        return state.update(**changes)

    def build(self, state: gr.State):
        """Create the dataset controls and wire their events to ``state``.

        Initial widget values and visibility are read from ``state.value``.
        """
        options = state.value
        generating = options.mode == "generate"

        with gr.Column():
            with gr.Row():
                mode = gr.Radio(
                    label="Dataset",
                    choices=["generate", "csv"],
                    value=options.mode,
                )

            with gr.Row():
                function = gr.Textbox(
                    label="Function (in terms of x1 and x2)",
                    value=options.function,
                    visible=generating,
                )

            with gr.Row():
                x1_textbox = gr.Textbox(
                    label="x1 range",
                    value=f"{options.x1lim[0]}, {options.x1lim[1]}",
                    interactive=True,
                    visible=generating,
                )
                x2_textbox = gr.Textbox(
                    label="x2 range",
                    value=f"{options.x2lim[0]}, {options.x2lim[1]}",
                    interactive=True,
                    visible=generating,
                )
                x_selection_method = gr.Radio(
                    label="How to select x points",
                    choices=["Evenly spaced", "Uniformly sampled"],
                    value=(
                        "Uniformly sampled"
                        if options.random_x
                        else "Evenly spaced"
                    ),
                )

            with gr.Row():
                sigma = gr.Number(
                    label="Gaussian noise standard deviation",
                    value=options.sigma,
                    visible=generating,
                )
                nsample = gr.Number(
                    label="Number of points",
                    value=options.nsample,
                    visible=generating,
                )
                regenerate = gr.Button("Regenerate Data", visible=generating)

            csv_upload = gr.File(
                label="Upload CSV file - must have columns: (x1, x2, y)",
                file_types=['.csv'],
                visible=not generating,
            )

        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[
                state,
                function,
                x1_textbox,
                x2_textbox,
                sigma,
                nsample,
                regenerate,
                csv_upload,
            ],
        )

        # generate mode
        function.submit(
            lambda f, s: s.update(function=f),
            inputs=[function, state],
            outputs=[state],
        )
        x1_textbox.submit(
            fn=self.update_x1lim,
            inputs=[x1_textbox, state],
            outputs=[state],
        )
        x2_textbox.submit(
            fn=self.update_x2lim,
            inputs=[x2_textbox, state],
            outputs=[state],
        )
        x_selection_method.change(
            fn=self.update_x_selection_method,
            inputs=[x_selection_method, state],
            outputs=[state],
        )
        sigma.submit(
            lambda sig, s: s.update(sigma=sig),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.submit(
            lambda n, s: s.update(nsample=n),
            inputs=[nsample, state],
            outputs=[state],
        )
        # Pick up any edits that were typed but not submitted, then re-seed.
        regenerate.click(
            fn=self.update_all,
            inputs=[function, x1_textbox, x2_textbox, sigma, nsample, state],
            outputs=[state],
        ).then(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv mode
        csv_upload.upload(
            self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )