Joel Woodfield commited on
Commit
0c4e50c
·
1 Parent(s): b43e0ab

Add function input error handling

Browse files
backend/src/__pycache__/manager.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/manager.cpython-312.pyc and b/backend/src/__pycache__/manager.cpython-312.pyc differ
 
backend/src/manager.py CHANGED
@@ -4,7 +4,12 @@ import matplotlib.lines as mlines
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  from matplotlib.figure import Figure
7
- from sympy import sympify
 
 
 
 
 
8
 
9
  from logic import (
10
  DataGenerationOptions,
@@ -75,10 +80,31 @@ class Manager:
75
  y_col: int,
76
  ) -> Dataset:
77
  if dataset_type == "Generate":
 
 
 
 
 
 
 
 
 
 
 
78
  try:
79
- parsed_function = sympify(function)
 
 
 
 
 
80
  except Exception as e:
81
- raise ValueError(f"Invalid function {e}")
 
 
 
 
 
82
 
83
  x1_range = self._parse_range(x1_range_input)
84
  x2_range = self._parse_range(x2_range_input)
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  from matplotlib.figure import Figure
7
+ from sympy import sympify, symbols, sin, cos, exp
8
+ from sympy.parsing.sympy_parser import (
9
+ standard_transformations,
10
+ implicit_multiplication_application,
11
+ parse_expr,
12
+ )
13
 
14
  from logic import (
15
  DataGenerationOptions,
 
80
  y_col: int,
81
  ) -> Dataset:
82
  if dataset_type == "Generate":
83
+ x1, x2 = symbols("x1 x2")
84
+ allowed_locals = {
85
+ "x1": x1,
86
+ "x2": x2,
87
+ "sin": sin,
88
+ "cos": cos,
89
+ "exp": exp,
90
+ }
91
+ if not function.strip():
92
+ raise ValueError("Function cannot be empty")
93
+
94
  try:
95
+ parsed_function = parse_expr(
96
+ function,
97
+ local_dict=allowed_locals,
98
+ transformations=standard_transformations + (implicit_multiplication_application,),
99
+ evaluate=True,
100
+ )
101
  except Exception as e:
102
+ raise ValueError(f"Invalid function: {e}")
103
+
104
+ unknown_symbols = parsed_function.free_symbols - {x1, x2}
105
+ if unknown_symbols:
106
+ unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
107
+ raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x1, x2")
108
 
109
  x1_range = self._parse_range(x1_range_input)
110
  x2_range = self._parse_range(x2_range_input)
frontends/gradio/main.py CHANGED
@@ -71,26 +71,30 @@ def handle_generate_plots(
71
  w2_range_input: str,
72
  resolution: int,
73
  ) -> tuple[Manager, Figure, Figure, Figure]:
74
- return manager.handle_generate_plots(
75
- dataset_type,
76
- function,
77
- x1_range_input,
78
- x2_range_input,
79
- x_selection_method,
80
- sigma,
81
- nsample,
82
- csv_file,
83
- has_header,
84
- x1_col,
85
- x2_col,
86
- y_col,
87
- loss_type,
88
- regularizer_type,
89
- reg_levels_input,
90
- w1_range_input,
91
- w2_range_input,
92
- resolution,
93
- )
 
 
 
 
94
 
95
  def handle_use_suggested_settings(
96
  manager: Manager,
@@ -107,20 +111,23 @@ def handle_use_suggested_settings(
107
  x2_col: int,
108
  y_col: int,
109
  ) -> tuple[Manager, str, str, str]:
110
- return manager.handle_use_suggested_settings(
111
- dataset_type,
112
- function,
113
- x1_range_input,
114
- x2_range_input,
115
- x_selection_method,
116
- sigma,
117
- nsample,
118
- csv_file,
119
- has_header,
120
- x1_col,
121
- x2_col,
122
- y_col,
123
- )
 
 
 
124
 
125
 
126
  def launch():
 
71
  w2_range_input: str,
72
  resolution: int,
73
  ) -> tuple[Manager, Figure, Figure, Figure]:
74
+ try:
75
+ return manager.handle_generate_plots(
76
+ dataset_type,
77
+ function,
78
+ x1_range_input,
79
+ x2_range_input,
80
+ x_selection_method,
81
+ sigma,
82
+ nsample,
83
+ csv_file,
84
+ has_header,
85
+ x1_col,
86
+ x2_col,
87
+ y_col,
88
+ loss_type,
89
+ regularizer_type,
90
+ reg_levels_input,
91
+ w1_range_input,
92
+ w2_range_input,
93
+ resolution,
94
+ )
95
+ except Exception as e:
96
+ raise gr.Error("Error generating plots: " + str(e))
97
+
98
 
99
  def handle_use_suggested_settings(
100
  manager: Manager,
 
111
  x2_col: int,
112
  y_col: int,
113
  ) -> tuple[Manager, str, str, str]:
114
+ try:
115
+ return manager.handle_use_suggested_settings(
116
+ dataset_type,
117
+ function,
118
+ x1_range_input,
119
+ x2_range_input,
120
+ x_selection_method,
121
+ sigma,
122
+ nsample,
123
+ csv_file,
124
+ has_header,
125
+ x1_col,
126
+ x2_col,
127
+ y_col,
128
+ )
129
+ except Exception as e:
130
+ raise gr.Error("Error computing suggested settings: " + str(e))
131
 
132
 
133
  def launch():