joel-woodfield committed on
Commit
cfcc7b6
·
1 Parent(s): d25398a

Add flexible dataset controls

Browse files
Files changed (1) hide show
  1. dataset.py +349 -0
dataset.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import numexpr
4
+ import pandas as pd
5
+ import time
6
+
7
+
8
+ NUMEXPR_CONSTANTS = {
9
+ 'pi': np.pi,
10
+ 'PI': np.pi,
11
+ 'e': np.e,
12
+ }
13
+
14
+
15
def get_function(function, x1lim, x2lim, nsample=100):
    """Evaluate ``function`` on a regular ``nsample`` x ``nsample`` grid.

    Args:
        function: numexpr expression string written in terms of ``x1`` and
            ``x2``; the names in ``NUMEXPR_CONSTANTS`` are available too.
        x1lim: ``(low, high)`` bounds of the grid along x1.
        x2lim: ``(low, high)`` bounds of the grid along x2.
        nsample: Number of grid points along each axis.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(nsample**2, 2)`` holding the
        flattened grid coordinates and ``y`` has shape ``(nsample**2,)``
        holding the expression value at each grid point.
    """
    x1 = np.linspace(x1lim[0], x1lim[1], nsample)
    x2 = np.linspace(x2lim[0], x2lim[1], nsample)
    mesh_x1, mesh_x2 = np.meshgrid(x1, x2)

    # NOTE(review): `function` appears to originate from a user-editable
    # textbox elsewhere in this file; confirm the installed numexpr version
    # sanitises expressions before exposing this publicly.
    y = numexpr.evaluate(
        function,
        local_dict={'x1': mesh_x1, 'x2': mesh_x2, **NUMEXPR_CONSTANTS},
    )

    # An expression that references neither x1 nor x2 (e.g. "5") evaluates to
    # a scalar; expand it so y always has one value per grid point.
    if np.shape(y) != mesh_x1.shape:
        y = np.broadcast_to(y, mesh_x1.shape).copy()

    X = np.stack([mesh_x1.ravel(), mesh_x2.ravel()], axis=1)
    return X, y.ravel()
28
+
29
+
30
def get_data_points(function, x1lim, x2lim, nsample=10, sigma=0., seed=0):
    """Sample random inputs and noisy targets for ``function``.

    A fixed pool of 100 candidate values is always drawn per coordinate and
    only the first ``nsample`` are kept, so for a given seed a larger
    ``nsample`` extends the previous point set instead of replacing it.

    Args:
        function: numexpr expression string written in terms of ``x1`` and
            ``x2``; the names in ``NUMEXPR_CONSTANTS`` are available too.
        x1lim: ``(low, high)`` bounds for the uniform draw of x1.
        x2lim: ``(low, high)`` bounds for the uniform draw of x2.
        nsample: Number of points to return (integral value, 0..100).
        sigma: Standard deviation of the Gaussian noise added to ``y``.
        seed: Seed for the random generator.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(nsample, 2)`` and ``y`` has shape
        ``(nsample,)``.

    Raises:
        ValueError: If ``nsample`` is not a whole number, is negative, or
            exceeds the pool size.
    """
    num_points_to_generate = 100

    # UI number widgets may hand over floats such as 10.0; slicing and
    # `standard_normal` both need a real int.
    if int(nsample) != nsample:
        raise ValueError("nsample must be a whole number")
    nsample = int(nsample)
    if nsample < 0:
        raise ValueError("nsample must be non-negative")
    if nsample > num_points_to_generate:
        raise ValueError(f"nsample too large, limit to {num_points_to_generate}")

    rng = np.random.default_rng(seed)
    x1 = rng.uniform(x1lim[0], x1lim[1], size=num_points_to_generate)[:nsample]
    x2 = rng.uniform(x2lim[0], x2lim[1], size=num_points_to_generate)[:nsample]

    # Keep drawing from the same generator instead of re-seeding: a second
    # generator with the same seed would replay the bits already used for x1,
    # making the noise a deterministic function of the inputs. The pool draws
    # above always consume the same amount of randomness, so the noise prefix
    # is still stable when nsample changes.
    noise = sigma * rng.standard_normal(nsample)

    y = numexpr.evaluate(
        function,
        local_dict={'x1': x1, 'x2': x2, **NUMEXPR_CONSTANTS},
    )
    # Out-of-place add after broadcasting: handles constant expressions
    # (scalar result) and integer/boolean results that `+=` cannot upcast.
    y = np.broadcast_to(y, x1.shape) + noise

    X = np.stack([x1, x2], axis=1)
    return X, y
56
+
57
+
58
+ class Dataset:
59
+ def __init__(
60
+ self,
61
+ mode: str = "generate",
62
+ function: str = "25 * x1 + 50 * x2",
63
+ x1lim: tuple[float, float] = (-1, 1),
64
+ x2lim: tuple[float, float] = (-1, 1),
65
+ nsample: int = 100,
66
+ sigma: float = 0.0,
67
+ seed: int = 0,
68
+ csv_path: str | None = None,
69
+ ):
70
+ self.mode = mode
71
+
72
+ self.function = function
73
+ self.x1lim = x1lim
74
+ self.x2lim = x2lim
75
+ self.nsample = nsample
76
+ self.sigma = sigma
77
+ self.seed = seed
78
+
79
+ self.csv_path = csv_path
80
+
81
+ self.X, self.y = self._get_data()
82
+
83
+ def get_function(self, nsample: int = 100):
84
+ return get_function(
85
+ function=self.function,
86
+ x1lim=self.x1lim,
87
+ x2lim=self.x2lim,
88
+ nsample=nsample,
89
+ )
90
+
91
+ def _get_data(self):
92
+ if self.mode == "generate":
93
+ return get_data_points(
94
+ function=self.function,
95
+ x1lim=self.x1lim,
96
+ x2lim=self.x2lim,
97
+ nsample=self.nsample,
98
+ sigma=self.sigma,
99
+ seed=self.seed,
100
+ )
101
+
102
+ elif self.mode == "csv":
103
+ if self.csv_path is None:
104
+ return np.array([]), np.array([])
105
+
106
+ df = pd.read_csv(self.csv_path)
107
+ if df.shape[1] != 2:
108
+ raise ValueError("CSV file must have exactly two columns")
109
+
110
+ x = df.iloc[:, 0].values.reshape(-1, 1)
111
+ y = df.iloc[:, 1].values
112
+ return x, y
113
+
114
+ else:
115
+ raise ValueError(f"Unknown dataset mode: {self.mode}")
116
+
117
+ def update(self, **kwargs):
118
+ return Dataset(
119
+ mode=kwargs.get("mode", self.mode),
120
+ function=kwargs.get("function", self.function),
121
+ x1lim=kwargs.get("x1lim", self.x1lim),
122
+ x2lim=kwargs.get("x2lim", self.x2lim),
123
+ nsample=kwargs.get("nsample", self.nsample),
124
+ sigma=kwargs.get("sigma", self.sigma),
125
+ seed=kwargs.get("seed", self.seed),
126
+ csv_path=kwargs.get("csv_path", self.csv_path),
127
+ )
128
+
129
+ def _safe_hash(self, val: int | float) -> int | float | tuple[int, str]:
130
+ # special handling for -1 (same hash number as -2)
131
+ if val == -1:
132
+ return (-1, "special")
133
+ return val
134
+
135
+ def __hash__(self):
136
+ return hash(
137
+ (
138
+ self.mode,
139
+ self.function,
140
+ self._safe_hash(self.x1lim[0]),
141
+ self._safe_hash(self.x1lim[1]),
142
+ self._safe_hash(self.x2lim[0]),
143
+ self._safe_hash(self.x2lim[1]),
144
+ self.nsample,
145
+ self.sigma,
146
+ self.seed,
147
+ self.csv_path,
148
+ )
149
+ )
150
+
151
+
152
class DatasetView:
    """Gradio controls for editing the dataset held in a ``gr.State``.

    The handlers receive the state as their last input and call
    ``state.update(...)`` on it, i.e. they operate on the object stored in the
    state (a ``Dataset`` judging by the keyword names used) and return the
    replacement object.
    """

    @staticmethod
    def _parse_lim(text: str, name: str) -> tuple[float, ...]:
        """Parse a ``"low, high"`` string into a pair of floats.

        Raises:
            ValueError: If a value is not a number or there are not exactly
                two comma-separated values.
        """
        values = tuple(map(float, text.split(",")))
        if len(values) != 2:
            raise ValueError(f"{name} must have exactly two values")
        return values

    def update_mode(self, mode: str, state: gr.State):
        """Switch dataset mode and toggle which controls are visible.

        Returns:
            The new state followed by visibility updates for, in order:
            function, x1lim, x2lim, sigma, nsample, regenerate, csv upload.

        Raises:
            ValueError: If ``mode`` is not a known mode.
        """
        state = state.update(mode=mode)

        if mode == "generate":
            generate_visible = True
        elif mode == "csv":
            generate_visible = False
        else:
            raise ValueError(f"Unknown mode: {mode}")

        return (
            state,
            gr.update(visible=generate_visible),  # function
            gr.update(visible=generate_visible),  # x1lim
            gr.update(visible=generate_visible),  # x2lim
            gr.update(visible=generate_visible),  # sigma
            gr.update(visible=generate_visible),  # nsample
            gr.update(visible=generate_visible),  # regenerate
            gr.update(visible=not generate_visible),  # csv upload
        )

    def update_x1lim(self, x1lim_str: str, state: gr.State):
        """Apply a new x1 range; on bad input show a toast and keep the state."""
        try:
            state = state.update(x1lim=self._parse_lim(x1lim_str, "x1lim"))
        except Exception as e:
            gr.Info(f"⚠️ {e}")

        return state

    def update_x2lim(self, x2lim_str: str, state: gr.State):
        """Apply a new x2 range; on bad input show a toast and keep the state."""
        try:
            state = state.update(x2lim=self._parse_lim(x2lim_str, "x2lim"))
        except Exception as e:
            gr.Info(f"⚠️ {e}")

        return state

    def upload_csv(self, file, state):
        """Switch to csv mode using the uploaded file's path (``file.name``)."""
        try:
            state = state.update(
                mode="csv",
                csv_path=file.name,
            )
        except Exception as e:
            gr.Info(f"⚠️ {e}")

        return state

    def regenerate_data(self, state: gr.State):
        """Re-draw the data with a fresh seed derived from the current time."""
        # Millisecond clock folded into the 32-bit range.
        seed = int(time.time() * 1000) % (2 ** 32)
        return state.update(seed=seed)

    def update_all(
        self,
        function: str,
        x1lim_str: str,
        x2lim_str: str,
        sigma: float,
        nsample: int,
        state: gr.State,
    ):
        """Apply every generate-mode control value to the state at once.

        A range that fails to parse is reported with a toast and skipped, so
        the previous range is kept. All accepted values are then applied in a
        single ``state.update`` call, so the dataset is rebuilt once rather
        than once per field. If that rebuild fails the error is shown as a
        toast and the state is returned unchanged.
        """
        changes = {"function": function, "sigma": sigma, "nsample": nsample}

        for key, text in (("x1lim", x1lim_str), ("x2lim", x2lim_str)):
            try:
                changes[key] = self._parse_lim(text, key)
            except Exception as e:
                gr.Info(f"⚠️ {e}")

        try:
            return state.update(**changes)
        except Exception as e:
            gr.Info(f"⚠️ {e}")
            return state

    def build(self, state: gr.State):
        """Create the dataset controls and wire their events to ``state``.

        Initial widget values and visibility are taken from ``state.value`` so
        the controls match the state even when it does not start in
        ``"generate"`` mode.
        """
        options = state.value
        generate_visible = options.mode == "generate"

        with gr.Column():
            mode = gr.Radio(
                label="Dataset",
                choices=["generate", "csv"],
                value=options.mode,
            )

            function = gr.Textbox(
                label="Function (in terms of x1 and x2)",
                value=options.function,
                visible=generate_visible,
            )

            with gr.Row():
                x1_textbox = gr.Textbox(
                    label="x1 range",
                    value=f"{options.x1lim[0]}, {options.x1lim[1]}",
                    interactive=True,
                    visible=generate_visible,
                )
                x2_textbox = gr.Textbox(
                    label="x2 range",
                    value=f"{options.x2lim[0]}, {options.x2lim[1]}",
                    interactive=True,
                    visible=generate_visible,
                )

            with gr.Row():
                sigma = gr.Number(
                    label="Gaussian noise standard deviation",
                    value=options.sigma,
                    visible=generate_visible,
                )
                nsample = gr.Number(
                    label="Number of samples",
                    value=options.nsample,
                    # precision=0 makes the widget deliver an int; the sample
                    # count is used for slicing and array sizes downstream.
                    precision=0,
                    visible=generate_visible,
                )
                regenerate = gr.Button(
                    "Regenerate Data",
                    visible=generate_visible,
                )

            csv_upload = gr.File(
                label="Upload CSV file",
                file_types=['.csv'],
                visible=not generate_visible,
            )

        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[state, function, x1_textbox, x2_textbox, sigma, nsample, regenerate, csv_upload],
        )

        # generate mode
        function.submit(
            lambda f, s: s.update(function=f),
            inputs=[function, state],
            outputs=[state],
        )
        x1_textbox.submit(
            fn=self.update_x1lim,
            inputs=[x1_textbox, state],
            outputs=[state],
        )
        x2_textbox.submit(
            fn=self.update_x2lim,
            inputs=[x2_textbox, state],
            outputs=[state],
        )
        sigma.submit(
            lambda sig, s: s.update(sigma=sig),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.submit(
            lambda n, s: s.update(nsample=n),
            inputs=[nsample, state],
            outputs=[state],
        )
        # Pick up any edits that were typed but not submitted, then reseed.
        regenerate.click(
            fn=self.update_all,
            inputs=[function, x1_textbox, x2_textbox, sigma, nsample, state],
            outputs=[state],
        ).then(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv mode
        csv_upload.upload(
            self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )
348
+
349
+