Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import numexpr | |
| import pandas as pd | |
| import time | |
# Named constants made available to user-entered expressions, in addition to `x`.
# Both spellings of pi are accepted.
NUMEXPR_CONSTANTS = dict(
    pi=np.pi,
    PI=np.pi,
    e=np.e,
)
def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate ``function`` on an evenly spaced grid over ``xlim``.

    Args:
        function: Expression in terms of ``x`` that ``numexpr`` can evaluate.
            The names in ``NUMEXPR_CONSTANTS`` are available as well.
        xlim: ``(lower, upper)`` bounds of the grid.
        nsample: Number of grid points.

    Returns:
        A pair ``(x, y)``: ``x`` is the grid reshaped to ``(nsample, 1)`` and
        ``y`` holds the function values, one per grid point.
    """
    # NOTE(review): `function` looks like free text typed by the user; confirm
    # that numexpr's own expression sanitisation is sufficient for this app.
    grid = np.linspace(xlim[0], xlim[1], nsample)
    values = numexpr.evaluate(function, local_dict={'x': grid, **NUMEXPR_CONSTANTS})

    # An expression that never mentions x (e.g. "1" or "pi") does not inherit
    # the grid's shape, so stretch it to one value per grid point.
    if np.shape(values) != grid.shape:
        values = np.broadcast_to(values, grid.shape).copy()

    return grid.reshape(-1, 1), values
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
    """Sample noisy observations of ``function`` at random locations in ``xlim``.

    Args:
        function: Expression in terms of ``x`` that ``numexpr`` can evaluate.
            The names in ``NUMEXPR_CONSTANTS`` are available as well.
        xlim: ``(lower, upper)`` bounds of the uniform sampling interval.
        nsample: Number of points to return (at most 100).
        sigma: Scale applied to the standard-normal noise added to ``y``.
        seed: Seed used for both the locations and the noise.

    Returns:
        A pair ``(x, y)``: ``x`` holds the sorted sample locations reshaped to
        ``(nsample, 1)`` and ``y`` the noisy function values.

    Raises:
        ValueError: If ``nsample`` exceeds the size of the location pool.
    """
    pool_size = 100
    if nsample > pool_size:
        raise ValueError(f"nsample too large, limit to {pool_size}")

    # Always draw the whole pool and keep a prefix: for a fixed seed, a
    # smaller nsample then selects a subset of the locations of a larger one.
    pool = np.random.default_rng(seed).uniform(xlim[0], xlim[1], size=pool_size)
    locations = np.sort(pool[:nsample])

    # The noise comes from a fresh generator re-seeded with the same seed.
    noise = sigma * np.random.default_rng(seed).standard_normal(nsample)

    # NOTE(review): `function` looks like free text typed by the user; confirm
    # that numexpr's own expression sanitisation is sufficient for this app.
    clean = numexpr.evaluate(function, local_dict={'x': locations, **NUMEXPR_CONSTANTS})

    return locations.reshape(-1, 1), clean + noise
| class Dataset: | |
| def __init__( | |
| self, | |
| mode: str = "generate", | |
| function: str = "sin(2 * pi * x)", | |
| xmin: float = -1.0, | |
| xmax: float = 1.0, | |
| nsample: int = 30, | |
| sigma: float = 0.0, | |
| seed: int = 0, | |
| csv_path: str = None, | |
| ): | |
| self.mode = mode | |
| self.function = function | |
| self.xmin = xmin | |
| self.xmax = xmax | |
| self.nsample = nsample | |
| self.sigma = sigma | |
| self.seed = seed | |
| self.csv_path = csv_path | |
| self.x, self.y = self._get_data() | |
| def _get_data(self): | |
| if self.mode == "generate": | |
| return get_data_points( | |
| function=self.function, | |
| xlim=(self.xmin, self.xmax), | |
| nsample=self.nsample, | |
| sigma=self.sigma, | |
| seed=self.seed, | |
| ) | |
| elif self.mode == "csv": | |
| if self.csv_path is None: | |
| return np.array([]), np.array([]) | |
| df = pd.read_csv(self.csv_path) | |
| if df.shape[1] != 2: | |
| raise ValueError("CSV file must have exactly two columns") | |
| x = df.iloc[:, 0].values.reshape(-1, 1) | |
| y = df.iloc[:, 1].values | |
| return x, y | |
| else: | |
| raise ValueError(f"Unknown dataset mode: {self.mode}") | |
| def update(self, **kwargs): | |
| return Dataset( | |
| mode=kwargs.get("mode", self.mode), | |
| function=kwargs.get("function", self.function), | |
| xmin=kwargs.get("xmin", self.xmin), | |
| xmax=kwargs.get("xmax", self.xmax), | |
| nsample=kwargs.get("nsample", self.nsample), | |
| sigma=kwargs.get("sigma", self.sigma), | |
| seed=kwargs.get("seed", self.seed), | |
| csv_path=kwargs.get("csv_path", self.csv_path), | |
| ) | |
| def _safe_hash(self, val: int) -> int | tuple[int, str]: | |
| # special handling for -1 (same hash number as -2) | |
| if val == -1: | |
| return (-1, "special") | |
| return val | |
| def __hash__(self): | |
| return hash( | |
| ( | |
| self.mode, | |
| self.function, | |
| self._safe_hash(self.xmin), | |
| self._safe_hash(self.xmax), | |
| self.nsample, | |
| self.sigma, | |
| self.seed, | |
| self.csv_path, | |
| ) | |
| ) | |
class DatasetView:
    """Gradio controls for editing the ``Dataset`` stored in a ``gr.State``.

    The event handlers receive the current state *value* (an object with a
    ``Dataset``-style ``update(**fields)`` method) and return the new value.
    """

    @staticmethod
    def _field_setter(field):
        """Build a handler ``(value, state) -> state.update(field=value)``."""
        def setter(value, state):
            return state.update(**{field: value})
        return setter

    def update_mode(self, mode: str, state: "Dataset"):
        """Switch the dataset mode and toggle the matching controls.

        Returns:
            The new state followed by visibility updates for, in order:
            function, xmin, xmax, sigma, nsample, regenerate, csv upload.

        Raises:
            ValueError: If ``mode`` is neither ``"generate"`` nor ``"csv"``.
        """
        state = state.update(mode=mode)
        if mode not in ("generate", "csv"):
            raise ValueError(f"Unknown mode: {mode}")

        generating = mode == "generate"
        return (
            state,
            gr.update(visible=generating),  # function
            gr.update(visible=generating),  # xmin
            gr.update(visible=generating),  # xmax
            gr.update(visible=generating),  # sigma
            gr.update(visible=generating),  # nsample
            gr.update(visible=generating),  # regenerate
            gr.update(visible=not generating),  # csv upload
        )

    def upload_csv(self, file, state):
        """Load an uploaded CSV into the state.

        On any failure the problem is shown to the user and the previous
        state is returned unchanged.
        """
        try:
            # The upload may arrive either as a plain path string or as an
            # object exposing the path as `.name`; accept both.
            csv_path = file if isinstance(file, str) else file.name
            state = state.update(
                mode="csv",
                csv_path=csv_path,
            )
        except Exception as e:
            # UI boundary: report instead of crashing the event handler.
            gr.Info(f"⚠️ {e}")
        return state

    def regenerate_data(self, state: "Dataset"):
        """Re-draw the data with a seed taken from the current time in ms."""
        seed = int(time.time() * 1000) % (2 ** 32)
        return state.update(seed=seed)

    def update_all(self, function, xmin, xmax, sigma, nsample, state):
        """Apply every generate-mode control value to the state at once."""
        return state.update(
            function=function,
            xmin=xmin,
            xmax=xmax,
            sigma=sigma,
            nsample=nsample,
        )

    def build(self, state: "gr.State"):
        """Create the widgets and wire their events to ``state``.

        Initial values and visibility are taken from ``state.value``.
        """
        options = state.value
        # Fall back to generate mode if the stored value carries no mode.
        initial_mode = getattr(options, "mode", "generate")
        generating = initial_mode == "generate"

        with gr.Column():
            mode = gr.Radio(
                label="Dataset",
                choices=["generate", "csv"],
                value=initial_mode,
            )
            function = gr.Textbox(
                label="Function (in terms of x)",
                value=options.function,
                visible=generating,
            )
            with gr.Row():
                xmin = gr.Number(
                    label="x min",
                    value=options.xmin,
                    visible=generating,
                )
                xmax = gr.Number(
                    label="x max",
                    value=options.xmax,
                    visible=generating,
                )
            sigma = gr.Number(
                label="Gaussian noise standard deviation",
                value=options.sigma,
                visible=generating,
            )
            nsample = gr.Slider(
                label="Number of samples",
                minimum=0,
                maximum=100,
                step=1,
                value=options.nsample,
                visible=generating,
            )
            regenerate = gr.Button("Regenerate Data", visible=generating)
            csv_upload = gr.File(
                label="Upload CSV file",
                file_types=['.csv'],
                visible=not generating,
            )

        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[state, function, xmin, xmax, sigma, nsample, regenerate, csv_upload],
        )

        # generate mode: each control pushes its own field into the state
        function.submit(
            self._field_setter("function"),
            inputs=[function, state],
            outputs=[state],
        )
        xmin.submit(
            self._field_setter("xmin"),
            inputs=[xmin, state],
            outputs=[state],
        )
        xmax.submit(
            self._field_setter("xmax"),
            inputs=[xmax, state],
            outputs=[state],
        )
        sigma.submit(
            self._field_setter("sigma"),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.change(
            self._field_setter("nsample"),
            inputs=[nsample, state],
            outputs=[state],
        )

        # Pick up any edits that were typed but not submitted, then reseed.
        regenerate.click(
            self.update_all,
            inputs=[function, xmin, xmax, sigma, nsample, state],
            outputs=[state],
        ).then(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv mode
        csv_upload.upload(
            self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )