"""Two-feature dataset (generated from an expression or read from CSV) and
the Gradio controls that edit it."""

import logging
import time

import gradio as gr
import numexpr
import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)

# Names an expression may use in addition to ``x1`` and ``x2``.
NUMEXPR_CONSTANTS = {
    'pi': np.pi,
    'PI': np.pi,
    'e': np.e,
}


def _evaluate(function: str, x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Evaluate the expression ``function`` element-wise at the points (x1, x2).

    A constant expression (one that mentions neither ``x1`` nor ``x2``)
    evaluates to a 0-d result; it is expanded to the shape of ``x1`` so that
    callers always get one value per point.
    """
    # NOTE(review): ``function`` is typed by the user into a textbox and handed
    # to numexpr unchanged. Confirm that the installed numexpr version
    # sanitises expressions before exposing this app to untrusted users.
    result = np.asarray(
        numexpr.evaluate(
            function,
            local_dict={'x1': x1, 'x2': x2, **NUMEXPR_CONSTANTS},
        )
    )
    if result.ndim == 0:
        result = np.full(x1.shape, result[()])
    return result


def get_function(function, x1lim, x2lim, nsample=100):
    """Evaluate ``function`` on a regular ``nsample`` x ``nsample`` grid.

    Args:
        function: Expression in terms of ``x1`` and ``x2`` (plus the names in
            ``NUMEXPR_CONSTANTS``).
        x1lim: ``(low, high)`` bounds of the grid along ``x1``.
        x2lim: ``(low, high)`` bounds of the grid along ``x2``.
        nsample: Number of grid points per axis.

    Returns:
        ``(mesh_x1, mesh_x2, y)``, three arrays of shape ``(nsample, nsample)``.
    """
    nsample = int(nsample)
    x1 = np.linspace(x1lim[0], x1lim[1], nsample)
    x2 = np.linspace(x2lim[0], x2lim[1], nsample)
    mesh_x1, mesh_x2 = np.meshgrid(x1, x2)
    y = _evaluate(function, mesh_x1.ravel(), mesh_x2.ravel())
    return mesh_x1, mesh_x2, y.reshape(mesh_x1.shape)


def get_data_points(function, x1lim, x2lim, nsample=10, sigma=0., random_x=False, seed=0):
    """Sample ``nsample`` noisy observations of ``function``.

    Args:
        function: Expression in terms of ``x1`` and ``x2``.
        x1lim: ``(low, high)`` bounds for ``x1``.
        x2lim: ``(low, high)`` bounds for ``x2``.
        nsample: Number of points to return; coerced to ``int`` because UI
            number widgets may deliver floats.
        sigma: Standard deviation of the Gaussian noise added to ``y``.
        random_x: If true, draw the inputs uniformly at random inside the
            bounds; otherwise take the first ``nsample`` points of the
            smallest square grid that holds at least ``nsample`` points.
        seed: Seed of the random generator.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(nsample, 2)`` and ``y`` has shape
        ``(nsample,)``.
    """
    nsample = int(nsample)
    # One generator for everything. Seeding a second generator with the same
    # seed for the noise would replay the bit stream that produced ``x1``,
    # making the noise depend on the sampled inputs.
    rng = np.random.default_rng(seed)

    if random_x:
        x1 = rng.uniform(x1lim[0], x1lim[1], size=nsample)
        x2 = rng.uniform(x2lim[0], x2lim[1], size=nsample)
    else:
        size = int(np.ceil(np.sqrt(nsample)))
        grid_x1, grid_x2 = np.meshgrid(
            np.linspace(x1lim[0], x1lim[1], size),
            np.linspace(x2lim[0], x2lim[1], size),
        )
        x1 = grid_x1.ravel()[:nsample]
        x2 = grid_x2.ravel()[:nsample]

    noise = sigma * rng.standard_normal(nsample)
    # Out-of-place addition: works for integer/boolean-valued expressions too,
    # where an in-place ``+=`` with float noise would raise.
    y = _evaluate(function, x1, x2) + noise

    X = np.stack([x1, x2], axis=1)
    return X, y


class Dataset:
    """Immutable-style description of a dataset together with its data.

    The data (``X``, ``y``) is produced once, in the constructor; changing a
    setting is done with :meth:`update`, which builds a new instance.
    """

    # Constructor arguments, in order; used by ``update`` to copy settings.
    _FIELDS = (
        "mode",
        "function",
        "x1lim",
        "x2lim",
        "nsample",
        "sigma",
        "random_x",
        "seed",
        "csv_path",
    )
    _MODES = ("generate", "csv")

    def __init__(
        self,
        mode: str = "generate",
        function: str = "25 * x1 + 30 * x2",
        x1lim: tuple[float, float] = (-1, 1),
        x2lim: tuple[float, float] = (-1, 1),
        nsample: int = 100,
        sigma: float = 0.1,
        random_x: bool = False,
        seed: int = 0,
        csv_path: str | None = None,
    ):
        """Store the settings and immediately build ``self.X`` and ``self.y``.

        Raises:
            ValueError: If ``mode`` is unknown or the CSV file does not have
                exactly three columns.
        """
        self.mode = mode
        self.function = function
        self.x1lim = x1lim
        self.x2lim = x2lim
        self.nsample = nsample
        self.sigma = sigma
        self.random_x = random_x
        self.seed = seed
        self.csv_path = csv_path

        self.X, self.y = self._get_data()

    def get_function(self, nsample: int = 100):
        """Return the noise-free function on an ``nsample`` x ``nsample`` grid."""
        return get_function(
            function=self.function,
            x1lim=self.x1lim,
            x2lim=self.x2lim,
            nsample=nsample,
        )

    def _get_data(self):
        """Build ``(X, y)`` according to ``self.mode``.

        In ``"csv"`` mode without a file (nothing uploaded yet) the generated
        data is used as a fallback.
        """
        if self.mode not in self._MODES:
            raise ValueError(f"Unknown dataset mode: {self.mode}")

        if self.mode == "generate" or self.csv_path is None:
            return get_data_points(
                function=self.function,
                x1lim=self.x1lim,
                x2lim=self.x2lim,
                nsample=self.nsample,
                sigma=self.sigma,
                random_x=self.random_x,
                seed=self.seed,
            )

        df = pd.read_csv(self.csv_path)
        if df.shape[1] != 3:
            raise ValueError("CSV file must have exactly three columns")
        # The last column is the target, the first two are the features.
        x = df.iloc[:, :-1].to_numpy()
        y = df.iloc[:, -1].to_numpy()
        return x, y

    def update(self, **kwargs):
        """Return a new ``Dataset`` with the given settings replaced.

        Keyword arguments that are not constructor settings are ignored.
        """
        settings = {
            name: kwargs.get(name, getattr(self, name)) for name in self._FIELDS
        }
        return Dataset(**settings)

    def _safe_hash(self, val: int | float) -> int | float | tuple[int, str]:
        """Return a stand-in for ``val`` that hashes differently from ``-2``."""
        # special handling for -1 (same hash number as -2)
        if val == -1:
            return (-1, "special")
        return val

    def __hash__(self):
        """Hash every setting, so two datasets that differ in any of them
        (almost always) hash differently.

        NOTE(review): presumably this exists so that ``gr.State`` can detect
        that the dataset changed; confirm against the code that listens to
        the state.
        """
        return hash(
            (
                self.mode,
                self.function,
                self._safe_hash(self.x1lim[0]),
                self._safe_hash(self.x1lim[1]),
                self._safe_hash(self.x2lim[0]),
                self._safe_hash(self.x2lim[1]),
                self.nsample,
                self.sigma,
                self.random_x,
                self.seed,
                self.csv_path,
            )
        )


class DatasetView:
    """Gradio widgets for editing a :class:`Dataset` held in a ``gr.State``.

    Every handler receives the current ``Dataset`` (the value of the state)
    and returns the dataset that should replace it.
    """

    _EVENLY_SPACED = "Evenly spaced"
    _UNIFORMLY_SAMPLED = "Uniformly sampled"

    @staticmethod
    def _parse_limits(text: str, name: str) -> tuple[float, float]:
        """Parse ``"low, high"`` into a pair of floats.

        Raises:
            ValueError: If a part is not a number or there are not exactly
                two parts.
        """
        values = tuple(float(part) for part in text.split(","))
        if len(values) != 2:
            raise ValueError(f"{name} must have exactly two values")
        return values

    def update_mode(self, mode: str, state: Dataset):
        """Switch between generated and CSV data and toggle the widgets.

        Returns:
            The new dataset followed by visibility updates for, in order:
            function, x1lim, x2lim, sigma, nsample, regenerate, csv upload.

        Raises:
            ValueError: If ``mode`` is neither ``"generate"`` nor ``"csv"``.
        """
        if mode not in ("generate", "csv"):
            raise ValueError(f"Unknown mode: {mode}")

        state = state.update(mode=mode)
        generating = mode == "generate"
        # function, x1lim, x2lim, sigma, nsample, regenerate
        generator_controls = [gr.update(visible=generating) for _ in range(6)]
        return (
            state,
            *generator_controls,
            gr.update(visible=not generating),  # csv upload
        )

    def update_x1lim(self, x1lim_str: str, state: Dataset):
        """Apply a new ``x1`` range; on failure notify the user and keep the
        previous dataset."""
        try:
            state = state.update(x1lim=self._parse_limits(x1lim_str, "x1lim"))
        except Exception as e:  # UI boundary: report instead of crashing
            gr.Info(f"⚠️ {e}")
        return state

    def update_x2lim(self, x2lim_str: str, state: Dataset):
        """Apply a new ``x2`` range; on failure notify the user and keep the
        previous dataset."""
        try:
            state = state.update(x2lim=self._parse_limits(x2lim_str, "x2lim"))
        except Exception as e:  # UI boundary: report instead of crashing
            gr.Info(f"⚠️ {e}")
        return state

    def update_x_selection_method(self, method: str, state: Dataset):
        """Choose between grid inputs and uniformly sampled inputs."""
        random_x = method == self._UNIFORMLY_SAMPLED
        logger.debug("Updating random_x to %s", random_x)
        return state.update(random_x=random_x)

    def upload_csv(self, file, state: Dataset):
        """Load the uploaded CSV; on failure notify the user and keep the
        previous dataset."""
        try:
            state = state.update(
                mode="csv",
                csv_path=file.name,
            )
        except Exception as e:  # UI boundary: report instead of crashing
            gr.Info(f"⚠️ {e}")
        return state

    def regenerate_data(self, state: Dataset):
        """Re-draw the data with a seed derived from the current time."""
        seed = int(time.time() * 1000) % (2 ** 32)
        return state.update(seed=seed)

    def update_all(
        self,
        function: str,
        x1lim_str: str,
        x2lim_str: str,
        sigma: float,
        nsample: int,
        state: Dataset,
    ):
        """Apply every generator setting currently shown in the UI.

        A range that cannot be parsed is reported to the user and the
        previous range is kept. All accepted settings are applied with a
        single ``update`` so the data is generated only once.
        """
        changes = {"function": function, "sigma": sigma, "nsample": nsample}
        for name, text in (("x1lim", x1lim_str), ("x2lim", x2lim_str)):
            try:
                changes[name] = self._parse_limits(text, name)
            except ValueError as e:
                gr.Info(f"⚠️ {e}")
        return state.update(**changes)

    def build(self, state: gr.State):
        """Create the widgets and wire their events to ``state``.

        Must be called inside a ``gr.Blocks`` context. Initial widget values
        and visibility are taken from the dataset stored in ``state``.
        """
        options = state.value
        generating = options.mode == "generate"

        with gr.Column():
            with gr.Row():
                mode = gr.Radio(
                    label="Dataset",
                    choices=["generate", "csv"],
                    value=options.mode,
                )

            with gr.Row():
                function = gr.Textbox(
                    label="Function (in terms of x1 and x2)",
                    value=options.function,
                    visible=generating,
                )

            with gr.Row():
                x1_textbox = gr.Textbox(
                    label="x1 range",
                    value=f"{options.x1lim[0]}, {options.x1lim[1]}",
                    interactive=True,
                    visible=generating,
                )
                x2_textbox = gr.Textbox(
                    label="x2 range",
                    value=f"{options.x2lim[0]}, {options.x2lim[1]}",
                    interactive=True,
                    visible=generating,
                )
                x_selection_method = gr.Radio(
                    label="How to select x points",
                    choices=[self._EVENLY_SPACED, self._UNIFORMLY_SAMPLED],
                    value=(
                        self._UNIFORMLY_SAMPLED
                        if options.random_x
                        else self._EVENLY_SPACED
                    ),
                )

            with gr.Row():
                sigma = gr.Number(
                    label="Gaussian noise standard deviation",
                    value=options.sigma,
                    visible=generating,
                )
                nsample = gr.Number(
                    label="Number of points",
                    value=options.nsample,
                    precision=0,  # deliver an int, not a float
                    visible=generating,
                )
                regenerate = gr.Button("Regenerate Data", visible=generating)

            csv_upload = gr.File(
                label="Upload CSV file - must have columns: (x1, x2, y)",
                file_types=['.csv'],
                visible=not generating,  # hidden while data is generated
            )

        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[
                state,
                function,
                x1_textbox,
                x2_textbox,
                sigma,
                nsample,
                regenerate,
                csv_upload,
            ],
        )

        # generate mode
        function.submit(
            lambda f, s: s.update(function=f),
            inputs=[function, state],
            outputs=[state],
        )
        x1_textbox.submit(
            fn=self.update_x1lim,
            inputs=[x1_textbox, state],
            outputs=[state],
        )
        x2_textbox.submit(
            fn=self.update_x2lim,
            inputs=[x2_textbox, state],
            outputs=[state],
        )
        x_selection_method.change(
            fn=self.update_x_selection_method,
            inputs=[x_selection_method, state],
            outputs=[state],
        )
        sigma.submit(
            lambda sig, s: s.update(sigma=sig),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.submit(
            lambda n, s: s.update(nsample=n),
            inputs=[nsample, state],
            outputs=[state],
        )
        # Pick up any edits that were typed but not submitted, then re-seed.
        regenerate.click(
            fn=self.update_all,
            inputs=[function, x1_textbox, x2_textbox, sigma, nsample, state],
            outputs=[state],
        ).then(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv mode
        csv_upload.upload(
            self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )