"""Dataset state and Gradio controls for sampling a function of x or loading (x, y) pairs from CSV."""

import time

import gradio as gr
import numexpr
import numpy as np
import pandas as pd

# Names made available to user-entered expressions in addition to ``x``.
NUMEXPR_CONSTANTS = {
    'pi': np.pi,
    'PI': np.pi,
    'e': np.e,
}

# Size of the pool of x positions drawn by get_data_points; also used as the
# upper bound of the "Number of samples" slider so the two cannot drift apart.
MAX_NSAMPLE = 100


def _evaluate_expression(function, x):
    """Evaluate the expression string ``function`` at ``x`` with numexpr.

    NOTE(review): ``function`` is free text typed by the user and is handed
    to ``numexpr.evaluate``. Confirm that the installed numexpr version
    sanitizes expressions before exposing this to untrusted users.

    Args:
        function: Expression in terms of ``x`` and the NUMEXPR_CONSTANTS names.
        x: 1-D array of evaluation points.

    Returns:
        Array with the same shape as ``x``.
    """
    y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
    # An expression that does not reference x (e.g. "1" or "pi") evaluates to
    # a 0-d array; expand it so callers always get one value per x.
    if np.shape(y) != x.shape:
        y = np.broadcast_to(y, x.shape).copy()
    return y


def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate ``function`` on an evenly spaced grid.

    Args:
        function: Expression in terms of ``x`` (see NUMEXPR_CONSTANTS for the
            extra names that may be used).
        xlim: ``(min, max)`` bounds of the grid, both inclusive.
        nsample: Number of grid points.

    Returns:
        Tuple ``(x, y)`` where ``x`` has shape ``(nsample, 1)`` and ``y`` has
        shape ``(nsample,)``.
    """
    x = np.linspace(xlim[0], xlim[1], nsample)
    y = _evaluate_expression(function, x)
    x = x.reshape(-1, 1)
    return x, y


def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
    """Sample ``function`` at random x positions and add Gaussian noise.

    A fixed pool of ``MAX_NSAMPLE`` uniform positions is drawn from a
    generator seeded with ``seed`` and only the first ``nsample`` are kept
    (then sorted). For a given seed and ``xlim`` the positions used for a
    smaller ``nsample`` are therefore a subset of those used for a larger one.

    Args:
        function: Expression in terms of ``x``.
        xlim: ``(min, max)`` bounds of the uniform distribution for x.
        nsample: Number of points to return, between 0 and ``MAX_NSAMPLE``.
        sigma: Standard deviation of the noise added to y.
        seed: Seed for both the x positions and the noise.

    Returns:
        Tuple ``(x, y)`` where ``x`` has shape ``(nsample, 1)`` (sorted
        ascending) and ``y`` has shape ``(nsample,)``.

    Raises:
        ValueError: If ``nsample`` is negative or exceeds ``MAX_NSAMPLE``.
    """
    num_points_to_generate = MAX_NSAMPLE
    if nsample > num_points_to_generate:
        raise ValueError(f"nsample too large, limit to {num_points_to_generate}")
    if nsample < 0:
        # Without this check a negative value would silently slice the pool
        # and only fail later, inside NumPy, with a less helpful message.
        raise ValueError("nsample must be non-negative")

    rng = np.random.default_rng(seed)
    x = rng.uniform(xlim[0], xlim[1], size=num_points_to_generate)
    x = np.sort(x[:nsample])

    # The noise comes from a second generator created with the same seed;
    # noise[i] is added to the i-th smallest x.
    rng = np.random.default_rng(seed)
    noise = sigma * rng.standard_normal(nsample)

    y = _evaluate_expression(function, x) + noise
    x = x.reshape(-1, 1)
    return x, y


class Dataset:
    """Dataset description together with the ``(x, y)`` arrays it yields.

    The arrays are computed once, in ``__init__``. ``update`` never mutates
    the instance; it returns a new ``Dataset``.

    Equality and hashing are based on the configuration fields only, not on
    the ``x``/``y`` arrays.
    """

    # Constructor arguments, in the order used for hashing and equality.
    _FIELDS = (
        "mode",
        "function",
        "xmin",
        "xmax",
        "nsample",
        "sigma",
        "seed",
        "csv_path",
    )

    def __init__(
        self,
        mode: str = "generate",
        function: str = "sin(2 * pi * x)",
        xmin: float = -1.0,
        xmax: float = 1.0,
        nsample: int = 30,
        sigma: float = 0.0,
        seed: int = 0,
        csv_path: str | None = None,
    ):
        """Store the configuration and load or generate the data.

        Args:
            mode: ``"generate"`` to sample ``function``, ``"csv"`` to read
                ``csv_path``.
            function: Expression in terms of ``x`` (generate mode).
            xmin: Lower bound for generated x values.
            xmax: Upper bound for generated x values.
            nsample: Number of generated points.
            sigma: Standard deviation of the generated noise.
            seed: Random seed for generation.
            csv_path: Path of a two-column CSV file (csv mode), or ``None``.

        Raises:
            ValueError: If ``mode`` is unknown, if the CSV file does not have
                exactly two columns, or if generation rejects ``nsample``.
        """
        self.mode = mode
        self.function = function
        self.xmin = xmin
        self.xmax = xmax
        self.nsample = nsample
        self.sigma = sigma
        self.seed = seed
        self.csv_path = csv_path

        self.x, self.y = self._get_data()

    def _get_data(self):
        """Return ``(x, y)`` for the current configuration.

        ``x`` is always a column of shape ``(n, 1)`` and ``y`` has shape
        ``(n,)``; in csv mode without a file both are empty.
        """
        if self.mode == "generate":
            return get_data_points(
                function=self.function,
                xlim=(self.xmin, self.xmax),
                nsample=self.nsample,
                sigma=self.sigma,
                seed=self.seed,
            )

        if self.mode == "csv":
            if self.csv_path is None:
                # No file yet: empty data, but with the same (n, 1) / (n,)
                # layout as every other path.
                return np.empty((0, 1)), np.empty(0)

            df = pd.read_csv(self.csv_path)
            if df.shape[1] != 2:
                raise ValueError("CSV file must have exactly two columns")

            x = df.iloc[:, 0].values.reshape(-1, 1)
            y = df.iloc[:, 1].values
            return x, y

        raise ValueError(f"Unknown dataset mode: {self.mode}")

    def update(self, **kwargs):
        """Return a new ``Dataset`` with the given fields replaced.

        Fields that are not passed keep their current value. Keyword
        arguments that are not constructor fields are ignored.
        """
        params = {
            name: kwargs.get(name, getattr(self, name))
            for name in self._FIELDS
        }
        return Dataset(**params)

    def _safe_hash(self, val: float) -> float | tuple[int, str]:
        """Return a stand-in for ``val`` that hashes differently for -1 and -2."""
        # special handling for -1 (same hash number as -2)
        if val == -1:
            return (-1, "special")
        return val

    def _key(self):
        """Return the configuration fields as a tuple, in ``_FIELDS`` order."""
        return tuple(getattr(self, name) for name in self._FIELDS)

    def __eq__(self, other):
        """Compare by configuration fields so equality agrees with ``__hash__``."""
        if not isinstance(other, Dataset):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        """Hash the configuration fields (the data arrays are not included)."""
        return hash(
            (
                self.mode,
                self.function,
                self._safe_hash(self.xmin),
                self._safe_hash(self.xmax),
                self.nsample,
                self.sigma,
                self.seed,
                self.csv_path,
            )
        )


class DatasetView:
    """Gradio controls for editing a ``Dataset`` held in a ``gr.State``.

    The event handlers receive the current ``Dataset`` (the state's value)
    and return the replacement ``Dataset``.
    """

    def update_mode(self, mode: str, state: Dataset):
        """Switch the dataset mode and toggle which controls are visible.

        Args:
            mode: ``"generate"`` or ``"csv"``.
            state: Current dataset.

        Returns:
            Tuple of the new dataset followed by visibility updates for
            function, xmin, xmax, sigma, nsample, regenerate and csv upload,
            in that order.

        Raises:
            ValueError: If ``mode`` is not a known mode.
        """
        state = state.update(mode=mode)

        if mode not in ("generate", "csv"):
            raise ValueError(f"Unknown mode: {mode}")

        generate_visible = mode == "generate"
        return (
            state,
            gr.update(visible=generate_visible),  # function
            gr.update(visible=generate_visible),  # xmin
            gr.update(visible=generate_visible),  # xmax
            gr.update(visible=generate_visible),  # sigma
            gr.update(visible=generate_visible),  # nsample
            gr.update(visible=generate_visible),  # regenerate
            gr.update(visible=not generate_visible),  # csv upload
        )

    def upload_csv(self, file, state: Dataset):
        """Load an uploaded CSV file into the dataset.

        If the file cannot be loaded, a warning is shown and the previous
        dataset is returned unchanged.

        Args:
            file: Uploaded file; either an object with a ``name`` attribute
                holding the path, or the path itself. ``None`` is ignored.
            state: Current dataset.

        Returns:
            The new dataset, or ``state`` if nothing was loaded.
        """
        if file is None:
            return state

        csv_path = getattr(file, "name", file)
        try:
            state = state.update(
                mode="csv",
                csv_path=csv_path,
            )
        except Exception as e:
            # UI boundary: any read/parse failure is reported to the user
            # rather than propagated; the old dataset stays in place.
            gr.Warning(f"⚠️ {e}")

        return state

    def regenerate_data(self, state: Dataset):
        """Return a dataset regenerated with a seed taken from the clock.

        The seed is the current time in milliseconds, reduced modulo 2**32.
        """
        seed = int(time.time() * 1000) % (2 ** 32)
        state = state.update(seed=seed)
        return state

    def update_all(self, function, xmin, xmax, sigma, nsample, state: Dataset):
        """Return a dataset with every generate-mode field replaced at once."""
        state = state.update(
            function=function,
            xmin=xmin,
            xmax=xmax,
            sigma=sigma,
            nsample=nsample,
        )
        return state

    def build(self, state: gr.State):
        """Create the dataset controls and wire their events to ``state``.

        Initial values and visibility are taken from ``state.value``.

        Args:
            state: ``gr.State`` whose value is a ``Dataset``.
        """
        options = state.value
        generate_visible = options.mode == "generate"

        with gr.Column():
            mode = gr.Radio(
                label="Dataset",
                choices=["generate", "csv"],
                value=options.mode,
            )

            function = gr.Textbox(
                label="Function (in terms of x)",
                value=options.function,
                visible=generate_visible,
            )

            with gr.Row():
                xmin = gr.Number(
                    label="x min",
                    value=options.xmin,
                    visible=generate_visible,
                )
                xmax = gr.Number(
                    label="x max",
                    value=options.xmax,
                    visible=generate_visible,
                )

            sigma = gr.Number(
                label="Gaussian noise standard deviation",
                value=options.sigma,
                visible=generate_visible,
            )

            nsample = gr.Slider(
                label="Number of samples",
                minimum=0,
                maximum=MAX_NSAMPLE,
                step=1,
                value=options.nsample,
                visible=generate_visible,
            )

            regenerate = gr.Button(
                "Regenerate Data",
                visible=generate_visible,
            )

            csv_upload = gr.File(
                label="Upload CSV file",
                file_types=['.csv'],
                visible=not generate_visible,
            )

        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[state, function, xmin, xmax, sigma, nsample, regenerate, csv_upload],
        )

        # generate mode
        function.submit(
            lambda f, s: s.update(function=f),
            inputs=[function, state],
            outputs=[state],
        )
        xmin.submit(
            lambda xmn, s: s.update(xmin=xmn),
            inputs=[xmin, state],
            outputs=[state],
        )
        xmax.submit(
            lambda xmx, s: s.update(xmax=xmx),
            inputs=[xmax, state],
            outputs=[state],
        )
        sigma.submit(
            lambda sig, s: s.update(sigma=sig),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.change(
            lambda n, s: s.update(nsample=n),
            inputs=[nsample, state],
            outputs=[state],
        )
        # Pick up any edits that were typed but not submitted, then reseed.
        regenerate.click(
            self.update_all,
            inputs=[function, xmin, xmax, sigma, nsample, state],
            outputs=[state],
        ).then(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv mode
        csv_upload.upload(
            self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )