Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import numexpr | |
| import pandas as pd | |
| import time | |
# Named constants made available to user-entered numexpr expressions,
# alongside the x1/x2 input arrays. Both "pi" and "PI" spellings are accepted.
NUMEXPR_CONSTANTS = dict(
    pi=np.pi,
    PI=np.pi,
    e=np.e,
)
def get_function(function, x1lim, x2lim, nsample=100):
    """Evaluate ``function`` on a regular ``nsample`` x ``nsample`` grid.

    Args:
        function: numexpr expression in ``x1``, ``x2`` and the names in
            ``NUMEXPR_CONSTANTS``.
        x1lim: ``(low, high)`` range of the first input.
        x2lim: ``(low, high)`` range of the second input.
        nsample: number of grid points along each axis; coerced to ``int``.

    Returns:
        ``(mesh_x1, mesh_x2, y)``, three arrays of shape ``(nsample, nsample)``
        as produced by ``np.meshgrid``.
    """
    nsample = int(nsample)
    x1 = np.linspace(x1lim[0], x1lim[1], nsample)
    x2 = np.linspace(x2lim[0], x2lim[1], nsample)
    mesh_x1, mesh_x2 = np.meshgrid(x1, x2)
    y = numexpr.evaluate(
        function,
        local_dict={'x1': mesh_x1.ravel(), 'x2': mesh_x2.ravel(), **NUMEXPR_CONSTANTS}
    )
    if np.ndim(y) == 0:
        # An expression that ignores x1/x2 (e.g. "3") can evaluate to a 0-d
        # array, which cannot be reshaped to the grid; fill the grid instead.
        y = np.full(mesh_x1.shape, y)
    else:
        y = y.reshape(mesh_x1.shape)
    return mesh_x1, mesh_x2, y
def get_data_points(function, x1lim, x2lim, nsample=10, sigma=0., random_x=False, seed=0):
    """Sample ``nsample`` noisy observations of ``function`` over a 2-D box.

    Args:
        function: numexpr expression in ``x1``, ``x2`` and the names in
            ``NUMEXPR_CONSTANTS``.
        x1lim: ``(low, high)`` range of the first input.
        x2lim: ``(low, high)`` range of the second input.
        nsample: number of points; coerced to ``int`` because UI number
            widgets may deliver floats, which numpy sizes and slices reject.
        sigma: standard deviation of the additive Gaussian noise.
        random_x: if true, draw inputs uniformly at random; otherwise take
            them from a regular grid.
        seed: seed for every random draw (inputs and noise).

    Returns:
        ``(X, y)`` where ``X`` has shape ``(nsample, 2)`` (columns x1, x2) and
        ``y`` has shape ``(nsample,)``.
    """
    nsample = int(nsample)
    # A single generator for all draws: re-seeding a second generator with the
    # same seed would make the noise a function of the same random bits that
    # produced x1.
    rng = np.random.default_rng(seed)
    if random_x:
        x1 = rng.uniform(x1lim[0], x1lim[1], size=nsample)
        x2 = rng.uniform(x2lim[0], x2lim[1], size=nsample)
    else:
        # Smallest square grid with at least nsample nodes, truncated in
        # row-major order: when nsample is not a perfect square, the nodes with
        # the largest x2 values are dropped.
        size = int(np.ceil(np.sqrt(nsample)))
        grid_x1, grid_x2 = np.meshgrid(
            np.linspace(x1lim[0], x1lim[1], size),
            np.linspace(x2lim[0], x2lim[1], size),
        )
        x1 = grid_x1.ravel()[:nsample]
        x2 = grid_x2.ravel()[:nsample]
    noise = sigma * rng.standard_normal(nsample)
    y = numexpr.evaluate(
        function,
        local_dict={'x1': x1, 'x2': x2, **NUMEXPR_CONSTANTS}
    )
    # Out-of-place addition broadcasts a 0-d result (constant expression) and
    # upcasts boolean/integer results, where `y += noise` would raise.
    y = y + noise
    X = np.stack([x1, x2], axis=1)
    return X, y
| class Dataset: | |
| def __init__( | |
| self, | |
| mode: str = "generate", | |
| function: str = "25 * x1 + 30 * x2", | |
| x1lim: tuple[float, float] = (-1, 1), | |
| x2lim: tuple[float, float] = (-1, 1), | |
| nsample: int = 100, | |
| sigma: float = 0.1, | |
| random_x: bool = False, | |
| seed: int = 0, | |
| csv_path: str | None = None, | |
| ): | |
| self.mode = mode | |
| self.function = function | |
| self.x1lim = x1lim | |
| self.x2lim = x2lim | |
| self.nsample = nsample | |
| self.sigma = sigma | |
| self.random_x = random_x | |
| self.seed = seed | |
| self.csv_path = csv_path | |
| self.X, self.y = self._get_data() | |
| def get_function(self, nsample: int = 100): | |
| return get_function( | |
| function=self.function, | |
| x1lim=self.x1lim, | |
| x2lim=self.x2lim, | |
| nsample=nsample, | |
| ) | |
| def _get_data(self): | |
| if self.mode == "generate" or self.csv_path is None: | |
| return get_data_points( | |
| function=self.function, | |
| x1lim=self.x1lim, | |
| x2lim=self.x2lim, | |
| nsample=self.nsample, | |
| sigma=self.sigma, | |
| random_x=self.random_x, | |
| seed=self.seed, | |
| ) | |
| elif self.mode == "csv": | |
| if self.csv_path is None: | |
| raise RuntimeError("Something is wrong") | |
| df = pd.read_csv(self.csv_path) | |
| if df.shape[1] != 3: | |
| raise ValueError("CSV file must have exactly three columns") | |
| x = df.iloc[:, :-1].values | |
| y = df.iloc[:, -1].values | |
| return x, y | |
| else: | |
| raise ValueError(f"Unknown dataset mode: {self.mode}") | |
| def update(self, **kwargs): | |
| return Dataset( | |
| mode=kwargs.get("mode", self.mode), | |
| function=kwargs.get("function", self.function), | |
| x1lim=kwargs.get("x1lim", self.x1lim), | |
| x2lim=kwargs.get("x2lim", self.x2lim), | |
| nsample=kwargs.get("nsample", self.nsample), | |
| sigma=kwargs.get("sigma", self.sigma), | |
| random_x=kwargs.get("random_x", self.random_x), | |
| seed=kwargs.get("seed", self.seed), | |
| csv_path=kwargs.get("csv_path", self.csv_path), | |
| ) | |
| def _safe_hash(self, val: int | float) -> int | float | tuple[int, str]: | |
| # special handling for -1 (same hash number as -2) | |
| if val == -1: | |
| return (-1, "special") | |
| return val | |
| def __hash__(self): | |
| return hash( | |
| ( | |
| self.mode, | |
| self.function, | |
| self._safe_hash(self.x1lim[0]), | |
| self._safe_hash(self.x1lim[1]), | |
| self._safe_hash(self.x2lim[0]), | |
| self._safe_hash(self.x2lim[1]), | |
| self.nsample, | |
| self.sigma, | |
| self.random_x, | |
| self.seed, | |
| self.csv_path, | |
| ) | |
| ) | |
class DatasetView:
    """Gradio controls that edit the ``Dataset`` held in a ``gr.State``.

    Each event handler receives the state's current value (a ``Dataset``) and
    returns a replacement built with ``Dataset.update``.
    """

    @staticmethod
    def _parse_lim(lim_str: str, name: str) -> tuple[float, float]:
        """Parse ``"low, high"`` into a pair of floats.

        Raises:
            ValueError: if ``lim_str`` is not exactly two comma-separated numbers.
        """
        lim = tuple(map(float, lim_str.split(",")))
        if len(lim) != 2:
            raise ValueError(f"{name} must have exactly two values")
        return lim

    def update_mode(self, mode: str, state: "Dataset"):
        """Switch the data source and toggle the matching controls.

        Returns the new state followed by visibility updates for, in order:
        function, x1lim, x2lim, sigma, nsample, regenerate, csv upload.

        Raises:
            ValueError: if ``mode`` is neither "generate" nor "csv".
        """
        if mode not in ("generate", "csv"):
            raise ValueError(f"Unknown mode: {mode}")
        state = state.update(mode=mode)
        show_generate = mode == "generate"
        return (
            state,
            gr.update(visible=show_generate),  # function
            gr.update(visible=show_generate),  # x1lim
            gr.update(visible=show_generate),  # x2lim
            gr.update(visible=show_generate),  # sigma
            gr.update(visible=show_generate),  # nsample
            gr.update(visible=show_generate),  # regenerate
            gr.update(visible=not show_generate),  # csv upload
        )

    def update_x1lim(self, x1lim_str: str, state: "Dataset"):
        """Apply an x1 range typed as ``"low, high"``; on error, warn and keep the state."""
        try:
            state = state.update(x1lim=self._parse_lim(x1lim_str, "x1lim"))
        except Exception as e:  # UI boundary: report rather than fail the event
            gr.Info(f"⚠️ {e}")
        return state

    def update_x2lim(self, x2lim_str: str, state: "Dataset"):
        """Apply an x2 range typed as ``"low, high"``; on error, warn and keep the state."""
        try:
            state = state.update(x2lim=self._parse_lim(x2lim_str, "x2lim"))
        except Exception as e:  # UI boundary: report rather than fail the event
            gr.Info(f"⚠️ {e}")
        return state

    def update_x_selection_method(self, method: str, state: "Dataset"):
        """Use uniformly random x points iff ``method`` is "Uniformly sampled"."""
        return state.update(random_x=method == "Uniformly sampled")

    def upload_csv(self, file, state: "Dataset"):
        """Switch the state to CSV mode using the uploaded file.

        ``file`` may be a path string (the ``gr.File`` default in Gradio 4) or
        an object exposing the path as ``.name`` (older Gradio). A cleared
        upload (``None``) leaves the state unchanged.
        """
        if file is None:
            return state
        try:
            path = file if isinstance(file, str) else file.name
            state = state.update(mode="csv", csv_path=path)
        except Exception as e:  # UI boundary: bad CSVs are reported, not raised
            gr.Info(f"⚠️ {e}")
        return state

    def regenerate_data(self, state: "Dataset"):
        """Re-seed from the wall-clock time in milliseconds, reduced modulo 2**32."""
        seed = int(time.time() * 1000) % (2 ** 32)
        return state.update(seed=seed)

    def update_all(
        self,
        function: str,
        x1lim_str: str,
        x2lim_str: str,
        sigma: float,
        nsample: int,
        state: "Dataset",
    ):
        """Apply every generate-mode control in a single ``Dataset.update``.

        An unparsable range is reported with ``gr.Info`` and left unchanged.
        ``nsample`` is coerced to ``int``. Errors raised while building the new
        dataset (e.g. an invalid ``function``) propagate to Gradio.
        """
        changes = {"function": function, "sigma": sigma, "nsample": int(nsample)}
        for name, lim_str in (("x1lim", x1lim_str), ("x2lim", x2lim_str)):
            try:
                changes[name] = self._parse_lim(lim_str, name)
            except Exception as e:  # UI boundary: keep the previous range
                gr.Info(f"⚠️ {e}")
        # One update instead of one per field: each update rebuilds the data.
        return state.update(**changes)

    def build(self, state: "gr.State"):
        """Create the dataset controls and wire their events to ``state``.

        Initial control values and visibility come from ``state.value``.
        Intended to be called inside a ``gr.Blocks`` context, since event
        listeners are attached here.
        """
        options = state.value
        show_generate = options.mode != "csv"
        with gr.Column():
            with gr.Row():
                mode = gr.Radio(
                    label="Dataset",
                    choices=["generate", "csv"],
                    value="generate" if show_generate else "csv",
                )
            with gr.Row():
                function = gr.Textbox(
                    label="Function (in terms of x1 and x2)",
                    value=options.function,
                    visible=show_generate,
                )
            with gr.Row():
                x1_textbox = gr.Textbox(
                    label="x1 range",
                    value=f"{options.x1lim[0]}, {options.x1lim[1]}",
                    interactive=True,
                    visible=show_generate,
                )
                x2_textbox = gr.Textbox(
                    label="x2 range",
                    value=f"{options.x2lim[0]}, {options.x2lim[1]}",
                    interactive=True,
                    visible=show_generate,
                )
                x_selection_method = gr.Radio(
                    label="How to select x points",
                    choices=["Evenly spaced", "Uniformly sampled"],
                    value="Uniformly sampled" if options.random_x else "Evenly spaced",
                )
            with gr.Row():
                sigma = gr.Number(
                    label="Gaussian noise standard deviation",
                    value=options.sigma,
                    visible=show_generate,
                )
                # precision=0 makes Gradio deliver an int, as numpy sizes require.
                nsample = gr.Number(
                    label="Number of points",
                    value=options.nsample,
                    precision=0,
                    visible=show_generate,
                )
            regenerate = gr.Button("Regenerate Data", visible=show_generate)
            csv_upload = gr.File(
                label="Upload CSV file - must have columns: (x1, x2, y)",
                file_types=['.csv'],
                visible=not show_generate,
            )

        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[state, function, x1_textbox, x2_textbox, sigma, nsample, regenerate, csv_upload],
        )

        # generate mode
        function.submit(
            fn=lambda f, s: s.update(function=f),
            inputs=[function, state],
            outputs=[state],
        )
        x1_textbox.submit(
            fn=self.update_x1lim,
            inputs=[x1_textbox, state],
            outputs=[state],
        )
        x2_textbox.submit(
            fn=self.update_x2lim,
            inputs=[x2_textbox, state],
            outputs=[state],
        )
        x_selection_method.change(
            fn=self.update_x_selection_method,
            inputs=[x_selection_method, state],
            outputs=[state],
        )
        sigma.submit(
            fn=lambda sig, s: s.update(sigma=sig),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.submit(
            fn=lambda n, s: s.update(nsample=int(n)),
            inputs=[nsample, state],
            outputs=[state],
        )
        # Apply all pending control values first, then draw a fresh seed.
        regenerate.click(
            fn=self.update_all,
            inputs=[function, x1_textbox, x2_textbox, sigma, nsample, state],
            outputs=[state],
        ).then(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv mode
        csv_upload.upload(
            fn=self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )