Spaces:
Sleeping
Sleeping
Commit ·
07df3e3
1
Parent(s): 1cfecae
Fix bug in using csv dataset
Browse files- dataset.py +7 -7
- regularization.py +1 -1
dataset.py
CHANGED
|
@@ -89,7 +89,7 @@ class Dataset:
|
|
| 89 |
)
|
| 90 |
|
| 91 |
def _get_data(self):
|
| 92 |
-
if self.mode == "generate":
|
| 93 |
return get_data_points(
|
| 94 |
function=self.function,
|
| 95 |
x1lim=self.x1lim,
|
|
@@ -101,14 +101,14 @@ class Dataset:
|
|
| 101 |
|
| 102 |
elif self.mode == "csv":
|
| 103 |
if self.csv_path is None:
|
| 104 |
-
|
| 105 |
|
| 106 |
df = pd.read_csv(self.csv_path)
|
| 107 |
-
if df.shape[1] !=
|
| 108 |
-
raise ValueError("CSV file must have exactly
|
| 109 |
|
| 110 |
-
x = df.iloc[:,
|
| 111 |
-
y = df.iloc[:, 1].values
|
| 112 |
return x, y
|
| 113 |
|
| 114 |
else:
|
|
@@ -292,7 +292,7 @@ class DatasetView:
|
|
| 292 |
regenerate = gr.Button("Regenerate Data")
|
| 293 |
|
| 294 |
csv_upload = gr.File(
|
| 295 |
-
label="Upload CSV file",
|
| 296 |
file_types=['.csv'],
|
| 297 |
visible=False, # function mode is default
|
| 298 |
)
|
|
|
|
| 89 |
)
|
| 90 |
|
| 91 |
def _get_data(self):
|
| 92 |
+
if self.mode == "generate" or self.csv_path is None:
|
| 93 |
return get_data_points(
|
| 94 |
function=self.function,
|
| 95 |
x1lim=self.x1lim,
|
|
|
|
| 101 |
|
| 102 |
elif self.mode == "csv":
|
| 103 |
if self.csv_path is None:
|
| 104 |
+
raise RuntimeError("Something is wrong")
|
| 105 |
|
| 106 |
df = pd.read_csv(self.csv_path)
|
| 107 |
+
if df.shape[1] != 3:
|
| 108 |
+
raise ValueError("CSV file must have exactly three columns")
|
| 109 |
|
| 110 |
+
x = df.iloc[:, :-1].values
|
| 111 |
+
y = df.iloc[:, -1].values
|
| 112 |
return x, y
|
| 113 |
|
| 114 |
else:
|
|
|
|
| 292 |
regenerate = gr.Button("Regenerate Data")
|
| 293 |
|
| 294 |
csv_upload = gr.File(
|
| 295 |
+
label="Upload CSV file - must have columns: (x1, x2, y)",
|
| 296 |
file_types=['.csv'],
|
| 297 |
visible=False, # function mode is default
|
| 298 |
)
|
regularization.py
CHANGED
|
@@ -16,7 +16,7 @@ logging.basicConfig(
|
|
| 16 |
)
|
| 17 |
logger = logging.getLogger("ELVIS")
|
| 18 |
|
| 19 |
-
from dataset import Dataset, DatasetView
|
| 20 |
|
| 21 |
def min_corresponding_entries(W1, W2, w1, tol=0.1):
|
| 22 |
mask = (W1 <= w1)
|
|
|
|
| 16 |
)
|
| 17 |
logger = logging.getLogger("ELVIS")
|
| 18 |
|
| 19 |
+
from dataset import Dataset, DatasetView
|
| 20 |
|
| 21 |
def min_corresponding_entries(W1, W2, w1, tol=0.1):
|
| 22 |
mask = (W1 <= w1)
|