joel-woodfield commited on
Commit
07df3e3
·
1 Parent(s): 1cfecae

Fix bug in using csv dataset

Browse files
Files changed (2) hide show
  1. dataset.py +7 -7
  2. regularization.py +1 -1
dataset.py CHANGED
@@ -89,7 +89,7 @@ class Dataset:
89
  )
90
 
91
  def _get_data(self):
92
- if self.mode == "generate":
93
  return get_data_points(
94
  function=self.function,
95
  x1lim=self.x1lim,
@@ -101,14 +101,14 @@ class Dataset:
101
 
102
  elif self.mode == "csv":
103
  if self.csv_path is None:
104
- return np.array([]), np.array([])
105
 
106
  df = pd.read_csv(self.csv_path)
107
- if df.shape[1] != 2:
108
- raise ValueError("CSV file must have exactly two columns")
109
 
110
- x = df.iloc[:, 0].values.reshape(-1, 1)
111
- y = df.iloc[:, 1].values
112
  return x, y
113
 
114
  else:
@@ -292,7 +292,7 @@ class DatasetView:
292
  regenerate = gr.Button("Regenerate Data")
293
 
294
  csv_upload = gr.File(
295
- label="Upload CSV file",
296
  file_types=['.csv'],
297
  visible=False, # function mode is default
298
  )
 
89
  )
90
 
91
  def _get_data(self):
92
+ if self.mode == "generate" or self.csv_path is None:
93
  return get_data_points(
94
  function=self.function,
95
  x1lim=self.x1lim,
 
101
 
102
  elif self.mode == "csv":
103
  if self.csv_path is None:
104
+ raise RuntimeError("Something is wrong")
105
 
106
  df = pd.read_csv(self.csv_path)
107
+ if df.shape[1] != 3:
108
+ raise ValueError("CSV file must have exactly three columns")
109
 
110
+ x = df.iloc[:, :-1].values
111
+ y = df.iloc[:, -1].values
112
  return x, y
113
 
114
  else:
 
292
  regenerate = gr.Button("Regenerate Data")
293
 
294
  csv_upload = gr.File(
295
+ label="Upload CSV file - must have columns: (x1, x2, y)",
296
  file_types=['.csv'],
297
  visible=False, # function mode is default
298
  )
regularization.py CHANGED
@@ -16,7 +16,7 @@ logging.basicConfig(
16
  )
17
  logger = logging.getLogger("ELVIS")
18
 
19
- from dataset import Dataset, DatasetView, get_function
20
 
21
  def min_corresponding_entries(W1, W2, w1, tol=0.1):
22
  mask = (W1 <= w1)
 
16
  )
17
  logger = logging.getLogger("ELVIS")
18
 
19
+ from dataset import Dataset, DatasetView
20
 
21
  def min_corresponding_entries(W1, W2, w1, tol=0.1):
22
  mask = (W1 <= w1)