# gp_visualizer/old_code/dataset.py
# joel-woodfield: Refactor file structure (commit 03e72c7)
import gradio as gr
import numpy as np
import numexpr
import pandas as pd
import time
# Named constants made available to every numexpr expression in this module,
# in addition to the variable ``x``. Both spellings of pi are accepted.
NUMEXPR_CONSTANTS = dict(pi=np.pi, PI=np.pi, e=np.e)
def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate a function expression on an evenly spaced grid.

    Args:
        function: numexpr expression string written in terms of ``x``; the
            names in ``NUMEXPR_CONSTANTS`` are available as well.
        xlim: ``(lower, upper)`` bounds of the grid (both end points are
            included).
        nsample: number of grid points.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the grid as a column of shape
        ``(nsample, 1)`` and ``y`` holds the expression values with shape
        ``(nsample,)``.
    """
    grid = np.linspace(xlim[0], xlim[1], nsample)
    # NOTE(review): `function` is handed to numexpr for evaluation. If the
    # string can come from an untrusted user, confirm that the installed
    # numexpr version sanitises expressions before exposing this publicly.
    values = np.asarray(
        numexpr.evaluate(function, local_dict={'x': grid, **NUMEXPR_CONSTANTS})
    )
    if values.shape != grid.shape:
        # An expression that never mentions x (e.g. "1" or "pi") evaluates to
        # a single value; repeat it so y always lines up with the grid.
        values = np.full(grid.shape, values)
    return grid.reshape(-1, 1), values
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
    """Sample noisy observations of a function at random inputs.

    Args:
        function: numexpr expression string written in terms of ``x``; the
            names in ``NUMEXPR_CONSTANTS`` are available as well.
        xlim: ``(lower, upper)`` bounds of the uniform input distribution.
        nsample: number of observations to return (0 to 100).
        sigma: scale applied to standard-normal noise added to the outputs.
        seed: seed for the random generator; equal arguments give equal data.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the sorted inputs as a column of
        shape ``(nsample, 1)`` and ``y`` the noisy outputs.

    Raises:
        ValueError: if ``nsample`` is negative or larger than 100.
    """
    num_points_to_generate = 100
    if nsample > num_points_to_generate:
        raise ValueError(f"nsample too large, limit to {num_points_to_generate}")
    if nsample < 0:
        raise ValueError(f"nsample must be non-negative, got {nsample}")

    rng = np.random.default_rng(seed)
    # Always draw the full pool and keep a prefix, so that for a fixed seed a
    # larger nsample only adds inputs to the ones already chosen.
    pool = rng.uniform(xlim[0], xlim[1], size=num_points_to_generate)
    x = np.sort(pool[:nsample])

    # Keep drawing from the same generator. Creating a second generator with
    # the same seed would replay the random stream that produced the inputs,
    # so the noise would not be independent of them.
    noise = sigma * rng.standard_normal(nsample)

    y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS}) + noise
    return x.reshape(-1, 1), y
class Dataset:
def __init__(
self,
mode: str = "generate",
function: str = "sin(2 * pi * x)",
xmin: float = -1.0,
xmax: float = 1.0,
nsample: int = 30,
sigma: float = 0.0,
seed: int = 0,
csv_path: str = None,
):
self.mode = mode
self.function = function
self.xmin = xmin
self.xmax = xmax
self.nsample = nsample
self.sigma = sigma
self.seed = seed
self.csv_path = csv_path
self.x, self.y = self._get_data()
def _get_data(self):
if self.mode == "generate":
return get_data_points(
function=self.function,
xlim=(self.xmin, self.xmax),
nsample=self.nsample,
sigma=self.sigma,
seed=self.seed,
)
elif self.mode == "csv":
if self.csv_path is None:
return np.array([]), np.array([])
df = pd.read_csv(self.csv_path)
if df.shape[1] != 2:
raise ValueError("CSV file must have exactly two columns")
x = df.iloc[:, 0].values.reshape(-1, 1)
y = df.iloc[:, 1].values
return x, y
else:
raise ValueError(f"Unknown dataset mode: {self.mode}")
def update(self, **kwargs):
return Dataset(
mode=kwargs.get("mode", self.mode),
function=kwargs.get("function", self.function),
xmin=kwargs.get("xmin", self.xmin),
xmax=kwargs.get("xmax", self.xmax),
nsample=kwargs.get("nsample", self.nsample),
sigma=kwargs.get("sigma", self.sigma),
seed=kwargs.get("seed", self.seed),
csv_path=kwargs.get("csv_path", self.csv_path),
)
def _safe_hash(self, val: int) -> int | tuple[int, str]:
# special handling for -1 (same hash number as -2)
if val == -1:
return (-1, "special")
return val
def __hash__(self):
return hash(
(
self.mode,
self.function,
self._safe_hash(self.xmin),
self._safe_hash(self.xmax),
self.nsample,
self.sigma,
self.seed,
self.csv_path,
)
)
class DatasetView:
def update_mode(self, mode: str, state: gr.State):
state = state.update(mode=mode)
if mode == "generate":
return (
state,
gr.update(visible=True), # function
gr.update(visible=True), # xmin
gr.update(visible=True), # xmax
gr.update(visible=True), # sigma
gr.update(visible=True), # nsample
gr.update(visible=True), # regenerate
gr.update(visible=False), # csv upload
)
elif mode == "csv":
return (
state,
gr.update(visible=False), # function
gr.update(visible=False), # xmin
gr.update(visible=False), # xmax
gr.update(visible=False), # sigma
gr.update(visible=False), # nsample
gr.update(visible=False), # regenerate
gr.update(visible=True), # csv upload
)
else:
raise ValueError(f"Unknown mode: {mode}")
def upload_csv(self, file, state):
try:
state = state.update(
mode="csv",
csv_path=file.name,
)
except Exception as e:
gr.Info(f"⚠️ {e}")
return state
def regenerate_data(self, state: gr.State):
seed = int(time.time() * 1000) % (2 ** 32)
state = state.update(seed=seed)
return state
def update_all(self, function, xmin, xmax, sigma, nsample, state):
state = state.update(
function=function,
xmin=xmin,
xmax=xmax,
sigma=sigma,
nsample=nsample,
)
return state
def build(self, state: gr.State):
options = state.value
with gr.Column():
mode = gr.Radio(
label="Dataset",
choices=["generate", "csv"],
value="generate",
)
function = gr.Textbox(
label="Function (in terms of x)",
value=options.function,
)
with gr.Row():
xmin = gr.Number(
label="x min",
value=options.xmin,
)
xmax = gr.Number(
label="x max",
value=options.xmax,
)
sigma = gr.Number(
label="Gaussian noise standard deviation",
value=options.sigma,
)
nsample = gr.Slider(
label="Number of samples",
minimum=0,
maximum=100,
step=1,
value=options.nsample,
)
regenerate = gr.Button("Regenerate Data")
csv_upload = gr.File(
label="Upload CSV file",
file_types=['.csv'],
visible=False, # function mode is default
)
mode.change(
fn=self.update_mode,
inputs=[mode, state],
outputs=[state, function, xmin, xmax, sigma, nsample, regenerate, csv_upload],
)
# generate mode
function.submit(
lambda f, s: s.update(function=f),
inputs=[function, state],
outputs=[state],
)
xmin.submit(
lambda xmn, s: s.update(xmin=xmn),
inputs=[xmin, state],
outputs=[state],
)
xmax.submit(
lambda xmx, s: s.update(xmax=xmx),
inputs=[xmax, state],
outputs=[state],
)
sigma.submit(
lambda sig, s: s.update(sigma=sig),
inputs=[sigma, state],
outputs=[state],
)
nsample.change(
lambda n, s: s.update(nsample=n),
inputs=[nsample, state],
outputs=[state],
)
regenerate.click(
self.update_all,
inputs=[function, xmin, xmax, sigma, nsample, state],
outputs=[state],
).then(
fn=self.regenerate_data,
inputs=[state],
outputs=[state],
)
# csv mode
csv_upload.upload(
self.upload_csv,
inputs=[csv_upload, state],
outputs=[state],
)