nanye commited on
Commit
1a3143c
·
1 Parent(s): f2260db

initial commit

Browse files
Files changed (4) hide show
  1. README.md +5 -5
  2. decision_boundary.py +452 -0
  3. requirements.txt +6 -0
  4. util.py +92 -0
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
- title: Interactive Decision Boundary
3
- emoji: 😻
4
- colorFrom: indigo
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.44.1
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
1
  ---
2
+ title: ELVIS Decision Boundary
3
+ emoji: 📊
4
+ colorFrom: green
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.44.1
8
+ app_file: decision_boundary.py
9
  pinned: false
10
  ---
11
 
decision_boundary.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import inspect
4
+ import numpy as np
5
+ import pandas as pd
6
+ import io
7
+ from PIL import Image
8
+ import sklearn
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.svm import LinearSVC
11
+ from sklearn.base import ClassifierMixin
12
+ from sklearn.datasets import load_iris
13
+ from sklearn.decomposition import PCA
14
+ from sklearn.metrics import classification_report
15
+
16
+ from collections import deque
17
+
18
+ from util import *
19
+
20
+ import logging
21
+ # Configure the logger once at the start of your program
22
+ logging.basicConfig(
23
+ level=logging.INFO, # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
24
+ format="%(asctime)s [%(levelname)s] %(message)s", # log format
25
+ )
26
+ logger = logging.getLogger("ELVIS")
27
+
28
+ # TODO:
29
+ # - support for session: load a previous session and continue from there
30
+
31
+ def label2color(labels):
32
+ color_deque = deque(["red", "green", "blue", "yellow", "orange", "purple", "pink", "brown", "gray", "black"])
33
+
34
+ colors = []
35
+
36
+ for label in labels:
37
+ if label.lower() in color_deque:
38
+ colors.append(label.lower())
39
+ color_deque.remove(label.lower())
40
+ else:
41
+ colors.append(color_deque.popleft())
42
+
43
+ return colors
44
+
45
+ def toydata():
46
+ points = [['Red', 0.12375, 0.8516666666666667],
47
+ ['Red', 0.19, 0.8916666666666666],
48
+ ['Red', 0.27375, 0.9233333333333333],
49
+ ['Blue', 0.50625, 0.785],
50
+ ['Blue', 0.38375, 0.6733333333333333],
51
+ ['Blue', 0.28875, 0.595]]
52
+ df = pd.DataFrame(points, columns=['label', 'F1', 'F2'])
53
+ return df
54
+
55
+ class CoordinateProjection:
56
+ '''
57
+ TODO: allow user to specify different coordinates
58
+ '''
59
+
60
+ def __init__(self, n_components=None, dims=[0, 1]):
61
+ self.dims = dims
62
+
63
+ def fit_transform(self, X):
64
+ print(X)
65
+ self.mean = X.mean(axis=0)
66
+ return X[:, self.dims]
67
+
68
+ def transform(self, X):
69
+ return X[:, self.dims]
70
+
71
+ def inverse_transform(self, Z):
72
+ X = np.ones((len(Z), 1)) * self.mean
73
+ X[:, self.dims] = Z
74
+ return X
75
+
76
+ class InteractiveDecisionBoundary:
77
+ def __init__(self, width, height):
78
+ # initialized in draw_plot
79
+ #self.canvas_width = -1
80
+ #self.canvas_height = -1
81
+
82
+ self.canvas_width = width
83
+ self.canvas_height = height
84
+
85
+ self.classifiers = get_sklearn_classifiers()
86
+ self.dataloaders = get_sklearn_dataloaders()
87
+
88
+ self.embedders = {
89
+ 'CoordinateProjection': CoordinateProjection,
90
+ 'GaussianRandomProjection': sklearn.random_projection.GaussianRandomProjection,
91
+ 'SparseRandomProjection': sklearn.random_projection.SparseRandomProjection,
92
+ }
93
+ module = getattr(sklearn, 'decomposition')
94
+ for cls_name, cls in inspect.getmembers(module, inspect.isclass):
95
+ self.embedders[cls_name] = cls
96
+
97
+ self.Embedder = CoordinateProjection
98
+ #self.Embedder = PCA
99
+
100
+ # default classifier model
101
+ #self.Model = LogisticRegression
102
+ self.Model = LinearSVC
103
+ self.model_args = {}
104
+
105
+ # todo: support arbitrary number of classes and user-defined class labels
106
+ #self.dataset = toydata()
107
+ #iris = load_iris(as_frame=True)
108
+ #self.dataset = pd.concat([iris.data, iris.target], axis=1)
109
+ #self.dataset = self.dataset.rename(columns={'target': 'label'})
110
+ self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
111
+ self.dataset_type = 'Draw2D'
112
+ self.custom_selected = True
113
+
114
+ self.data_image = None
115
+ self.boundary_image = None
116
+
117
+ self.css ="""
118
+ #my-button {
119
+ height: 30px;
120
+ font-size: 16px;
121
+ }
122
+
123
+ #rowheight {
124
+ height: 90px;
125
+ }
126
+
127
+ .report-table {
128
+ border: 0 !important;
129
+ }
130
+ .report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
131
+ border: 0 !important;
132
+ padding: 6px 12px;
133
+ text-align: center;
134
+ }"""
135
+ def plot(self, decision_boundary=False):
136
+ '''
137
+ Plot data and decision boundary with matplotlib and return as PIL image.
138
+ '''
139
+
140
+ logger.info("Initializing figure")
141
+ fig = plt.figure(figsize=(self.canvas_width/100., self.canvas_height/100.0), dpi=100)
142
+ # set entire figure to be the canvas to allow simple conversion of mouse
143
+ # position to coordinates in the figure
144
+ ax = fig.add_axes([0., 0., 1., 1.]) #
145
+ ax.margins(x=0, y=0) # no padding in both directions
146
+
147
+ if self.dataset_type == 'Draw2D':
148
+ # draw canvas boundary
149
+ #ax.scatter([0, 0, 1, 1], [0, 1, 0, 1], color='brown')
150
+ ax.plot([0, 0, 1, 1, 0], [0, 1, 1, 0, 0], color='brown')
151
+ ax.axis('off')
152
+
153
+ # TODO: allow showing x and y axes with ticks and labels
154
+ if (self.dataset is not None and len(self.dataset) > 0):
155
+ try:
156
+ X = self.dataset.loc[:, self.dataset.columns != 'target'].values
157
+ y = self.dataset.target.values
158
+ logger.info("Data:\n" + str(X))
159
+ logger.info("Target:\n" + str(y))
160
+
161
+ # compute embedding
162
+ embedder = self.Embedder(n_components=2)
163
+ self.embedder = embedder
164
+ logger.info("Embedder = " + str(self.embedder))
165
+
166
+ Z = embedder.fit_transform(X)
167
+ logger.info("Projected data:\n" + str(Z))
168
+
169
+ #ax.set_title("Click to add points")
170
+ labels = np.unique(y)
171
+ colors = label2color(labels)
172
+ logger.info("Classes:\n" + str(labels))
173
+ logger.info("Colors:\n" + str(colors))
174
+ l2c = dict(zip(labels, colors))
175
+
176
+ # scatter plots for data
177
+ for l, label in enumerate(labels):
178
+ #print('class', label)
179
+ #ax.scatter(*zip(*self.dataset[self.dataset.label == label].features), color=label, label=label)
180
+ subset = Z[y == label]
181
+ ax.scatter(subset[:, 0], subset[:, 1], color=colors[l], label=label)
182
+ ax.legend()
183
+
184
+ # plot the decision boundary
185
+ if decision_boundary:
186
+ model = self.Model(**parse_param_string(self.model_args))
187
+ model.fit(X, y)
188
+ self.model = model
189
+
190
+ # plot decision boundary in the projected space
191
+ # xx, yy = np.meshgrid(np.linspace(Z[:, 0].min(), 1, 100), np.linspace(0, 1, 100))
192
+ if self.dataset_type == 'Draw2D':
193
+ xx, yy = np.meshgrid(np.linspace(0, 1, 100),
194
+ np.linspace(0, 1, 100))
195
+ else:
196
+ xx, yy = np.meshgrid(np.linspace(Z[:, 0].min(), Z[:, 0].max(), 100),
197
+ np.linspace(Z[:, 1].min(), Z[:, 1].max(), 100))
198
+
199
+ grid = np.c_[xx.ravel(), yy.ravel()]
200
+ #scores = clf.decision_function(grid)[:, 1].reshape(xx.shape)
201
+ #scores = clf.decision_function(grid).reshape(xx.shape)
202
+ #ax.contour(xx, yy, scores)#, levels=[0], colors="black", linestyles="--")
203
+ print('grid', grid)
204
+ print('inverse', embedder.inverse_transform(grid))
205
+ preds = model.predict(embedder.inverse_transform(grid)).reshape(xx.shape)
206
+ #print(preds.shape, xx.shape, yy.shape)
207
+ ax.scatter(xx.ravel(), yy.ravel(), c=[l2c[l] for l in preds.ravel()], s=1, alpha=0.5)
208
+ except Exception as e:
209
+ raise gr.Error(f"⚠️ {e}")
210
+
211
+ buf = io.BytesIO()
212
+ ax.figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
213
+ plt.close(fig)
214
+ buf.seek(0)
215
+ img = Image.open(buf)
216
+
217
+ # TODO: add a save function for saving screenshot
218
+ #img.save('image.png')
219
+
220
+ return img
221
+
222
+ def add_point(self, evt: gr.SelectData, label):
223
+ '''
224
+ Mouse click to add a point.
225
+ '''
226
+ if self.custom_selected:
227
+ if self.dataset_type != 'Draw2D':
228
+ self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
229
+ self.dataset_type = 'Draw2D'
230
+
231
+ # normalize clicked position to [0, 1]
232
+ x = evt.index[0] / self.canvas_width
233
+ y = 1 - evt.index[1] / self.canvas_height # flip y-axis to match matplotlib
234
+
235
+ self.dataset.loc[len(self.dataset)] = [label, x, y]
236
+
237
+ logger.info(f'clicked ({evt.index[0]}, {evt.index[1]}), mapped to ({x}, {y})')
238
+
239
+ vis = self.plot()
240
+ data_table = gr.Dataframe(value = self.dataset[:100], visible=True,
241
+ headers=list(self.dataset.columns))
242
+
243
+ return vis, data_table
244
+
245
+ # train a model and show decision boundary
246
+ def train(self):
247
+ image = self.plot(decision_boundary=True)
248
+ X = self.dataset.loc[:, self.dataset.columns != 'target'].values
249
+ y = self.dataset.target.values
250
+ pred = self.model.predict(X)
251
+ df = pd.DataFrame(classification_report(y, pred, output_dict=True)).T
252
+ summary = df.to_html(classes="report-table", float_format="%.2f")
253
+
254
+ return image, gr.HTML(visible=True), "<b>Classification report</b><br>" + summary
255
+
256
+ # clear dataset and replot
257
+ def clear(self):
258
+ self.dataset = self.dataset[0:0]
259
+ return self.plot()
260
+
261
+ # save dataset
262
+ def save(self):
263
+ # TODO: allow user-specified filename
264
+ self.dataset.to_csv('dataset.csv', index=False)
265
+
266
+ def update_model(self, classifier_name):
267
+ self.Model = self.classifiers[classifier_name]
268
+ self.args_textbox.value = ""
269
+ logger.info(f'Updated model to {self.model}')
270
+
271
+ return ""
272
+
273
+ def update_args(self, model_args):
274
+ self.model_args = model_args
275
+ print('updated model_args:', self.model_args)
276
+
277
+ def update_embedder(self, embedder):
278
+ self.Embedder = self.embedders[embedder]
279
+ print('updated embedder:', self.Embedder)
280
+ return self.plot()
281
+
282
+ def handle_dataset_radio(self, type):
283
+ if type == 'Draw2D':
284
+ self.custom_selected = True
285
+ return gr.File(visible=False), gr.Dropdown(visible=False), gr.Dropdown(visible=False), gr.Textbox(visible=True), gr.Button(visible=True), gr.Button(visible=True)
286
+ elif type == 'Upload':
287
+ self.custom_selected = False
288
+ return gr.File(visible=True), gr.Dropdown(visible=False), gr.Dropdown(visible=True), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
289
+ elif type == 'sklearn':
290
+ self.custom_selected = False
291
+ return gr.File(visible=False), gr.Dropdown(visible=True), gr.Dropdown(visible=True), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
292
+ else:
293
+ # TODO: better error handling
294
+ print('Error - unknown dataset type:', type)
295
+
296
+ def load_local_data_and_plot(self, filename):
297
+ if filename is not None:
298
+ self.dataset = read(filename)
299
+ self.dataset.target = self.dataset.target.astype(str)
300
+ self.dataset_type = 'Upload'
301
+ logger.info(f'Loaded dataset from {filename}')
302
+
303
+ vis = self.plot()
304
+ #data_html = self.dataset.to_html(classes="report-table", float_format="%.2f")
305
+ data_table = gr.Dataframe(value = self.dataset[:100], visible=True,
306
+ headers=list(self.dataset.columns))
307
+
308
+ return vis, data_table
309
+
310
+ def load_sklearn_data_and_plot(self, datasetname):
311
+ if datasetname is not None:
312
+ dataset = self.dataloaders[datasetname]()
313
+ X = dataset.data
314
+ y = dataset.target
315
+ if hasattr(dataset, 'feature_names'):
316
+ feature_names = dataset.feature_names
317
+ else:
318
+ feature_names = ['F{%d}' % i for i in range(len(X[0]))]
319
+
320
+ if hasattr(dataset, 'target_names'):
321
+ labels = dataset.target_names
322
+ else:
323
+ labels = ['C{%d}' % i for i in range(len(np.unique(y)))]
324
+
325
+ self.dataset = pd.DataFrame(X, columns=feature_names)
326
+ self.dataset['target'] = y.astype(str)
327
+ self.dataset_type = 'sklearn'
328
+ logger.info(f'Loaded dataset {datasetname}')
329
+
330
+ vis = self.plot()
331
+ #data_html = self.dataset.to_html(classes="report-table", float_format="%.2f")
332
+ data_table = gr.Dataframe(value = self.dataset[:100], visible=True,
333
+ headers=list(self.dataset.columns))
334
+ return vis, data_table
335
+
336
+ def launch(self):
337
+ # build the Gradio interface
338
+ with gr.Blocks(css=self.css) as demo:
339
+ # app title
340
+ gr.Markdown("<div style='text-align:left; font-size:40px; font-weight: bold;'>ELVIS Interactive Decision Boundary Visualizer</div>")
341
+
342
+ # GUI elements and layout
343
+ with gr.Row():
344
+ with gr.Column(scale=2):
345
+ self.data_image = gr.Image(value=self.plot(), container=True)
346
+
347
+ with gr.Column(scale=1):
348
+ with gr.Tab("Dataset"):
349
+ dataset_radio = gr.Radio(["Draw2D", "Upload", "sklearn"],
350
+ value="Draw2D", label="Dataset type", elem_id="rowheight")
351
+
352
+ # upload data
353
+ file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
354
+ self.file_chooser = file_chooser
355
+
356
+ # sklearn data dropdown menu
357
+ sklearn_data_selector = gr.Dropdown(choices=self.dataloaders,
358
+ label='Select dataset',
359
+ value='None',
360
+ visible=False,
361
+ allow_custom_value=True)
362
+ self.sklearn_data_selector = sklearn_data_selector
363
+
364
+ # embedder
365
+ embedder_selector = gr.Dropdown(choices=self.embedders,
366
+ label='Select embedder',
367
+ value='CoordinateProjection',
368
+ visible=False,
369
+ allow_custom_value=True)
370
+
371
+ # custom data
372
+ label = gr.Radio(["Red", "Green", "Blue"], value="Red", label="Choose point label", visible=True, elem_id="rowheight")
373
+ self.label = label
374
+
375
+ with gr.Row():
376
+ btn_clear = gr.Button("Clear", visible=True, elem_id="my-button")
377
+ self.btn_clear = btn_clear
378
+
379
+ btn_save = gr.Button("Save", visible=True, elem_id="my-button")
380
+ self.btn_save = btn_save
381
+
382
+ #data_html = gr.HTML(visible=True)
383
+ data_table = gr.Dataframe(visible=False)
384
+
385
+ # classifier selector
386
+ with gr.Tab("Classifier"):
387
+ # specify model
388
+ model_selector = gr.Dropdown(choices=self.classifiers,
389
+ #label='',
390
+ #value='Select classifier',
391
+ label='Select Classifier',
392
+ value='LinearSVC',
393
+ allow_custom_value=True)
394
+ self.model_selector = model_selector
395
+
396
+ # specify arguments
397
+ args_textbox = gr.Textbox(label="Classifier arguments")
398
+ self.args_textbox = args_textbox
399
+
400
+ model_selector.change(fn=self.update_model, inputs=model_selector, outputs=args_textbox)
401
+
402
+ btn_train = gr.Button("Train Model")
403
+
404
+ classification_summary = gr.HTML(visible=False)
405
+
406
+ with gr.Tab("Export"):
407
+ btn_export_data = gr.Button('Data')
408
+
409
+ btn_export_model = gr.Button('Model')
410
+
411
+ btn_export_code = gr.Button('Code')
412
+
413
+ with gr.Tab("Options"):
414
+ pass
415
+
416
+ with gr.Tab("Usage"):
417
+ pass
418
+
419
+
420
+ # event handlers for GUI elements
421
+ self.data_image.select(self.add_point, inputs=label,
422
+ outputs=(self.data_image, data_table))
423
+
424
+ dataset_radio.change(fn=self.handle_dataset_radio,
425
+ inputs=dataset_radio,
426
+ outputs=(file_chooser, sklearn_data_selector, embedder_selector, label, btn_clear, btn_save))
427
+
428
+ # events for custom dataset
429
+ btn_clear.click(fn=self.clear, outputs=self.data_image)
430
+ btn_save.click(fn=self.save)
431
+
432
+ # events for local dataset
433
+ file_chooser.change(fn=self.load_local_data_and_plot,
434
+ inputs=file_chooser,
435
+ outputs=(self.data_image, data_table))
436
+
437
+ # events for sklearn dataset
438
+ sklearn_data_selector.change(fn=self.load_sklearn_data_and_plot,
439
+ inputs=sklearn_data_selector,
440
+ outputs=(self.data_image, data_table))
441
+
442
+ embedder_selector.change(fn=self.update_embedder,
443
+ inputs=embedder_selector,
444
+ outputs=self.data_image)
445
+
446
+ btn_train.click(fn=self.update_args, inputs=args_textbox)
447
+ btn_train.click(fn=self.train, outputs=(self.data_image, classification_summary, classification_summary))
448
+
449
+ demo.launch()
450
+
451
+ visualizer = InteractiveDecisionBoundary(width=1200, height=900)
452
+ visualizer.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ pandas
4
+ scikit-learn
5
+ mpu
6
+ pillow
util.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import sklearn
3
+ from sklearn.linear_model import LogisticRegression
4
+ from sklearn.base import ClassifierMixin
5
+ from sklearn.datasets import *
6
+ import pkgutil
7
+ import importlib
8
+ import warnings
9
+ import ast
10
+ import pandas as pd
11
+
12
+ def safe_import_module(name):
13
+ try:
14
+ with warnings.catch_warnings():
15
+ warnings.simplefilter("ignore")
16
+ return importlib.import_module(name)
17
+ except Exception:
18
+ return None # or raise/log if desired
19
+
20
+ def get_sklearn_classifiers():
21
+ classifiers = {}
22
+
23
+ #for modname in dir(sklearn):
24
+ for _, modname, _ in pkgutil.walk_packages(sklearn.__path__, prefix="sklearn."):
25
+ if '._' in modname: # exclude hidden modules
26
+ continue
27
+
28
+ if modname.count('.') > 1: # exclude modules more than two levels deep
29
+ continue
30
+
31
+ #print(modname)
32
+ try:
33
+ #with warnings.catch_warnings():
34
+ #warnings.simplefilter("ignore")
35
+
36
+ module = importlib.import_module(modname)
37
+ for cls_name, cls in inspect.getmembers(module, inspect.isclass):
38
+ if '._' not in cls_name and ('ClassifierMixin' not in cls_name):
39
+ if issubclass(cls, ClassifierMixin) and cls.__module__.startswith("sklearn"):
40
+ classifiers[cls_name] = cls
41
+ #classifiers.append(f"{cls.__module__}.{cls_name}")
42
+ except:
43
+ continue
44
+
45
+ return classifiers
46
+
47
+ def get_sklearn_dataloaders():
48
+ dataloaders = {}
49
+ dataloaders['20newsgroup'] = fetch_20newsgroups
50
+ dataloaders['20newsgroup_vectorized'] = fetch_20newsgroups_vectorized
51
+ dataloaders['covtype'] = fetch_covtype
52
+ dataloaders['kddcup99'] = fetch_kddcup99
53
+ dataloaders['iris'] = load_iris
54
+
55
+ return dataloaders
56
+
57
+ def parse_param_string(param_str):
58
+ param_str = param_str.replace("*,", "") # Remove '*' if present
59
+ params = {}
60
+ for item in param_str.split(','):
61
+ if not item.strip():
62
+ continue
63
+ if '=' not in item:
64
+ continue
65
+ key, value = item.split('=', 1)
66
+ key = key.strip()
67
+ try:
68
+ value = ast.literal_eval(value.strip())
69
+ except Exception:
70
+ value = value.strip() # fallback: treat as string
71
+ params[key] = value
72
+ return params
73
+
74
+ def read(filename):
75
+ if filename.endswith(".csv"):
76
+ return pd.read_csv(filename)
77
+ elif filename.endswith(".xlsx") or filename.endswith(".xls"):
78
+ return pd.read_excel(filename)
79
+ elif filename.endswith(".parquet"):
80
+ return pd.read_parquet(filename)
81
+ elif filename.endswith(".feather"):
82
+ return pd.read_feather(filename)
83
+ elif filename.endswith(".json"):
84
+ return pd.read_json(filename)
85
+ else:
86
+ raise ValueError("Unsupported file format.")
87
+
88
+ if __name__ == '__main__':
89
+ #print(classifier_list)
90
+ s = "penalty='l2', *, dual=False, tol=0.0001, C=1.0, fit_intercept=True, intercept_scaling=1"
91
+ parsed = parse_param_string(s)
92
+ print(parsed)