joel-woodfield commited on
Commit
8caaf20
·
1 Parent(s): b70a3e1

Add more symbol support in custom function expressions

Browse files
backend/src/optimization_manager.py CHANGED
@@ -1,5 +1,26 @@
1
  import numpy as np
2
- from sympy import sympify, lambdify
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  from optimization_logic import *
5
 
@@ -83,24 +104,14 @@ class OptimizationManager:
83
 
84
  # function expression check
85
  try:
86
- expr = sympify(function)
87
- symbols = {s.name for s in expr.free_symbols}
88
- if mode == "univariate":
89
- return symbols in {frozenset({'x'}), frozenset(set())}
90
- elif mode == "bivariate":
91
- return symbols in {
92
- frozenset({'x', 'y'}),
93
- frozenset({'x'}),
94
- frozenset({'y'}),
95
- frozenset(set()),
96
- }
97
- else:
98
- return False
99
-
100
  except Exception as e:
101
- pass
102
 
103
- return False
 
 
 
104
 
105
  def _function_changed(self, function: str, mode: str) -> bool:
106
  function = function.strip()
@@ -114,8 +125,53 @@ class OptimizationManager:
114
  except Exception as e:
115
  self.trajectory_values = {"x": [], "y": []}
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def _compute_function_values(self, function: str, mode: str, xlim: list, ylim: list) -> None:
118
- expr = sympify(function)
 
119
  if mode == "univariate":
120
  x = np.linspace(xlim[0], xlim[1], 100)
121
  f = lambdify('x', expr, modules=['numpy'])
@@ -151,7 +207,7 @@ class OptimizationManager:
151
  def _compute_trajectory_values(self, settings: dict, steps: int) -> None:
152
  mode = settings.get("mode", "").lower().strip()
153
  algorithm = settings.get("algorithm", "").lower().strip().replace(" ", "_")
154
- function = sympify(settings.get("functionExpr", "").strip())
155
 
156
  if mode == "univariate":
157
  if algorithm == "gradient_descent":
 
1
  import numpy as np
2
+ from sympy import (
3
+ lambdify,
4
+ symbols,
5
+ sin,
6
+ cos,
7
+ tan,
8
+ asin,
9
+ acos,
10
+ atan,
11
+ exp,
12
+ log,
13
+ sqrt,
14
+ pi,
15
+ Abs,
16
+ )
17
+
18
+ from sympy.parsing.sympy_parser import (
19
+ standard_transformations,
20
+ implicit_multiplication_application,
21
+ convert_xor,
22
+ parse_expr,
23
+ )
24
 
25
  from optimization_logic import *
26
 
 
104
 
105
  # function expression check
106
  try:
107
+ expr = self._parse_function(function)
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
+ return False
110
 
111
+ if mode == "univariate" and symbols("y") in expr.free_symbols:
112
+ return False
113
+
114
+ return True
115
 
116
  def _function_changed(self, function: str, mode: str) -> bool:
117
  function = function.strip()
 
125
  except Exception as e:
126
  self.trajectory_values = {"x": [], "y": []}
127
 
128
+ def _parse_function(self, function: str) -> Expr:
129
+ if not function.strip():
130
+ raise ValueError("Function expression cannot be empty")
131
+
132
+ x, y = symbols("x y")
133
+ allowed_locals = {
134
+ 'x': x,
135
+ 'y': y,
136
+ 'sin': sin,
137
+ 'cos': cos,
138
+ 'tan': tan,
139
+ 'asin': asin,
140
+ 'acos': acos,
141
+ 'atan': atan,
142
+ 'log': log,
143
+ 'ln': log,
144
+ 'sqrt': sqrt,
145
+ 'abs': Abs,
146
+ 'exp': exp,
147
+ 'e': exp(1),
148
+ 'pi': pi,
149
+ 'π': pi,
150
+ }
151
+
152
+ try:
153
+ parsed_function = parse_expr(
154
+ function,
155
+ local_dict=allowed_locals,
156
+ transformations=standard_transformations + (
157
+ implicit_multiplication_application,
158
+ convert_xor,
159
+ ),
160
+ evaluate=True,
161
+ )
162
+ except Exception as e:
163
+ raise ValueError(f"Invalid function expression: {e}")
164
+
165
+ unknown_symbols = parsed_function.free_symbols - {x, y}
166
+ if unknown_symbols:
167
+ unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
168
+ raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x, y")
169
+
170
+ return parsed_function
171
+
172
  def _compute_function_values(self, function: str, mode: str, xlim: list, ylim: list) -> None:
173
+ expr = self._parse_function(function)
174
+
175
  if mode == "univariate":
176
  x = np.linspace(xlim[0], xlim[1], 100)
177
  f = lambdify('x', expr, modules=['numpy'])
 
207
  def _compute_trajectory_values(self, settings: dict, steps: int) -> None:
208
  mode = settings.get("mode", "").lower().strip()
209
  algorithm = settings.get("algorithm", "").lower().strip().replace(" ", "_")
210
+ function = self._parse_function(settings.get("functionExpr", "").strip())
211
 
212
  if mode == "univariate":
213
  if algorithm == "gradient_descent":
dist/assets/{index-DyZyiv0F.js → index-CZNS1f0O.js} RENAMED
The diff for this file is too large to render. See raw diff
 
dist/assets/{pyodide.worker-BeUH2O5o.js → pyodide.worker-Dr32d4MW.js} RENAMED
@@ -1,5 +1,26 @@
1
  (function(){"use strict";var i=`import numpy as np
2
- from sympy import sympify, lambdify
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  from optimization_logic import *
5
 
@@ -83,24 +104,14 @@ class OptimizationManager:
83
 
84
  # function expression check
85
  try:
86
- expr = sympify(function)
87
- symbols = {s.name for s in expr.free_symbols}
88
- if mode == "univariate":
89
- return symbols in {frozenset({'x'}), frozenset(set())}
90
- elif mode == "bivariate":
91
- return symbols in {
92
- frozenset({'x', 'y'}),
93
- frozenset({'x'}),
94
- frozenset({'y'}),
95
- frozenset(set()),
96
- }
97
- else:
98
- return False
99
-
100
  except Exception as e:
101
- pass
102
 
103
- return False
 
 
 
104
 
105
  def _function_changed(self, function: str, mode: str) -> bool:
106
  function = function.strip()
@@ -114,8 +125,53 @@ class OptimizationManager:
114
  except Exception as e:
115
  self.trajectory_values = {"x": [], "y": []}
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def _compute_function_values(self, function: str, mode: str, xlim: list, ylim: list) -> None:
118
- expr = sympify(function)
 
119
  if mode == "univariate":
120
  x = np.linspace(xlim[0], xlim[1], 100)
121
  f = lambdify('x', expr, modules=['numpy'])
@@ -151,7 +207,7 @@ class OptimizationManager:
151
  def _compute_trajectory_values(self, settings: dict, steps: int) -> None:
152
  mode = settings.get("mode", "").lower().strip()
153
  algorithm = settings.get("algorithm", "").lower().strip().replace(" ", "_")
154
- function = sympify(settings.get("functionExpr", "").strip())
155
 
156
  if mode == "univariate":
157
  if algorithm == "gradient_descent":
@@ -825,4 +881,4 @@ def adam_bivariate(
825
  "x": x_values,
826
  "y": y_values,
827
  "z": z_values,
828
- }`;const o="https://cdn.jsdelivr.net/pyodide/v0.26.1/full/pyodide.mjs";let e=null,t=null;async function f(){const{loadPyodide:n}=await import(o);e=await n({indexURL:"https://cdn.jsdelivr.net/pyodide/v0.26.1/full/"}),await e.loadPackage(["numpy","sympy"]),e.FS.writeFile("optimization_logic.py",l),e.FS.writeFile("optimization_manager.py",i),e.runPython("from optimization_manager import OptimizationManager; manager = OptimizationManager();"),t=e.globals.get("manager"),t||console.error("Failed to initialize optimization manager"),self.postMessage({type:"READY"})}function s(n){if(!n)return null;try{const a=n.toJs({dict_converter:Object.fromEntries});n.destroy&&n.destroy(),self.postMessage({type:"RESULT",data:a})}catch(a){console.error("Error handling Python result:",a)}}self.onmessage=async n=>{const a=n.data;if(!t){console.warn("Pyodide is not ready yet");return}switch(a.type){case"INIT":const r=e.toPy(a.settings);s(t.handle_update_settings(r));break;case"NEXT_STEP":s(t.handle_next_step());break;case"PREV_STEP":s(t.handle_prev_step());break;case"RESET":s(t.handle_reset());break;default:console.error("Unknown message type:",a);break}},f()})();
 
1
  (function(){"use strict";var i=`import numpy as np
2
+ from sympy import (
3
+ lambdify,
4
+ symbols,
5
+ sin,
6
+ cos,
7
+ tan,
8
+ asin,
9
+ acos,
10
+ atan,
11
+ exp,
12
+ log,
13
+ sqrt,
14
+ pi,
15
+ Abs,
16
+ )
17
+
18
+ from sympy.parsing.sympy_parser import (
19
+ standard_transformations,
20
+ implicit_multiplication_application,
21
+ convert_xor,
22
+ parse_expr,
23
+ )
24
 
25
  from optimization_logic import *
26
 
 
104
 
105
  # function expression check
106
  try:
107
+ expr = self._parse_function(function)
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
+ return False
110
 
111
+ if mode == "univariate" and symbols("y") in expr.free_symbols:
112
+ return False
113
+
114
+ return True
115
 
116
  def _function_changed(self, function: str, mode: str) -> bool:
117
  function = function.strip()
 
125
  except Exception as e:
126
  self.trajectory_values = {"x": [], "y": []}
127
 
128
+ def _parse_function(self, function: str) -> Expr:
129
+ if not function.strip():
130
+ raise ValueError("Function expression cannot be empty")
131
+
132
+ x, y = symbols("x y")
133
+ allowed_locals = {
134
+ 'x': x,
135
+ 'y': y,
136
+ 'sin': sin,
137
+ 'cos': cos,
138
+ 'tan': tan,
139
+ 'asin': asin,
140
+ 'acos': acos,
141
+ 'atan': atan,
142
+ 'log': log,
143
+ 'ln': log,
144
+ 'sqrt': sqrt,
145
+ 'abs': Abs,
146
+ 'exp': exp,
147
+ 'e': exp(1),
148
+ 'pi': pi,
149
+ 'π': pi,
150
+ }
151
+
152
+ try:
153
+ parsed_function = parse_expr(
154
+ function,
155
+ local_dict=allowed_locals,
156
+ transformations=standard_transformations + (
157
+ implicit_multiplication_application,
158
+ convert_xor,
159
+ ),
160
+ evaluate=True,
161
+ )
162
+ except Exception as e:
163
+ raise ValueError(f"Invalid function expression: {e}")
164
+
165
+ unknown_symbols = parsed_function.free_symbols - {x, y}
166
+ if unknown_symbols:
167
+ unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
168
+ raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x, y")
169
+
170
+ return parsed_function
171
+
172
  def _compute_function_values(self, function: str, mode: str, xlim: list, ylim: list) -> None:
173
+ expr = self._parse_function(function)
174
+
175
  if mode == "univariate":
176
  x = np.linspace(xlim[0], xlim[1], 100)
177
  f = lambdify('x', expr, modules=['numpy'])
 
207
  def _compute_trajectory_values(self, settings: dict, steps: int) -> None:
208
  mode = settings.get("mode", "").lower().strip()
209
  algorithm = settings.get("algorithm", "").lower().strip().replace(" ", "_")
210
+ function = self._parse_function(settings.get("functionExpr", "").strip())
211
 
212
  if mode == "univariate":
213
  if algorithm == "gradient_descent":
 
881
  "x": x_values,
882
  "y": y_values,
883
  "z": z_values,
884
+ }`;const o="https://cdn.jsdelivr.net/pyodide/v0.26.1/full/pyodide.mjs";let e=null,t=null;async function r(){const{loadPyodide:n}=await import(o);e=await n({indexURL:"https://cdn.jsdelivr.net/pyodide/v0.26.1/full/"}),await e.loadPackage(["numpy","sympy"]),e.FS.writeFile("optimization_logic.py",l),e.FS.writeFile("optimization_manager.py",i),e.runPython("from optimization_manager import OptimizationManager; manager = OptimizationManager();"),t=e.globals.get("manager"),t||console.error("Failed to initialize optimization manager"),self.postMessage({type:"READY"})}function s(n){if(!n)return null;try{const a=n.toJs({dict_converter:Object.fromEntries});n.destroy&&n.destroy(),self.postMessage({type:"RESULT",data:a})}catch(a){console.error("Error handling Python result:",a)}}self.onmessage=async n=>{const a=n.data;if(!t){console.warn("Pyodide is not ready yet");return}switch(a.type){case"INIT":const f=e.toPy(a.settings);s(t.handle_update_settings(f));break;case"NEXT_STEP":s(t.handle_next_step());break;case"PREV_STEP":s(t.handle_prev_step());break;case"RESET":s(t.handle_reset());break;default:console.error("Unknown message type:",a);break}},r()})();
dist/index.html CHANGED
@@ -5,7 +5,7 @@
5
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
  <title>Optimization</title>
8
- <script type="module" crossorigin src="/assets/index-DyZyiv0F.js"></script>
9
  <link rel="stylesheet" crossorigin href="/assets/index-CBOaLvz3.css">
10
  </head>
11
  <body>
 
5
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
  <title>Optimization</title>
8
+ <script type="module" crossorigin src="/assets/index-CZNS1f0O.js"></script>
9
  <link rel="stylesheet" crossorigin href="/assets/index-CBOaLvz3.css">
10
  </head>
11
  <body>
frontends/react/src/Sidebar.tsx CHANGED
@@ -19,20 +19,20 @@ const DEFAULT_HYPERPARAMETERS = {
19
 
20
 
21
  const UNIVARIATE_FUNCTION_OPTIONS = {
22
- "--Custom--": "3 * x^2 + x",
23
  "Quadratic": "x^2",
24
- "Cubic": "x^3 - 3*x^2 + 2*x",
25
- "Quartic": "x^4 - 4*x^3 + 6*x^2 - 4*x + 1",
26
  "Sine": "sin(x)",
27
- "Exponential": "exp(x) - 5",
28
  }
29
 
30
  const BIVARIATE_FUNCTION_OPTIONS = {
31
- "--Custom--": "x^2 + 3 * y^2",
32
  "Quadratic": "x^2 + y^2",
33
- "Ackley": "-20 * exp(-0.2 * sqrt(0.5 * (x^2 + y^2))) - exp(0.5 * (cos(2 * pi * x) + cos(2 * pi * y))) + e + 20",
34
- "Rasteringin": "20 + (x^2 - 10 * cos(2 * pi * x)) + (y^2 - 10 * cos(2 * pi * y))",
35
- "Rosenbrock": "(1 - x)^2 + 100 * (y - x^2)^2",
36
  }
37
 
38
 
@@ -68,7 +68,7 @@ export default function Sidebar({
68
  }
69
  }
70
 
71
- const [functionOption, setFunctionOption] = useState<string>("Quadratic");
72
 
73
  function handleFunctionOptionChange(option: string) {
74
  setFunctionOption(option);
@@ -78,7 +78,7 @@ export default function Sidebar({
78
 
79
  function handleModeChange(mode: SettingsUi["mode"]) {
80
  // When changing modes, reset function to Quadratic as some options are mode-specific
81
- const newFunctionOption = "Quadratic";
82
 
83
  const expr = mode === "Bivariate"
84
  ? BIVARIATE_FUNCTION_OPTIONS[newFunctionOption as keyof typeof BIVARIATE_FUNCTION_OPTIONS]
 
19
 
20
 
21
  const UNIVARIATE_FUNCTION_OPTIONS = {
22
+ "--Custom--": "x^2",
23
  "Quadratic": "x^2",
24
+ "Cubic": "x^3 - 3x^2 + 2x",
25
+ "Quartic": "x^4 - 4x^3 + 6x^2 - 4x + 1",
26
  "Sine": "sin(x)",
27
+ "Exponential": "exp(x)",
28
  }
29
 
30
  const BIVARIATE_FUNCTION_OPTIONS = {
31
+ "--Custom--": "x^2 + 3y^2",
32
  "Quadratic": "x^2 + y^2",
33
+ "Ackley": "-20exp(-0.2 sqrt(0.5 (x^2 + y^2))) - exp(0.5 (cos(2 pi x) + cos(2 pi y))) + e + 20",
34
+ "Rasteringin": "20 + (x^2 - 10 cos(2 * pi * x)) + (y^2 - 10 cos(2 pi y))",
35
+ "Rosenbrock": "(1 - x)^2 + 100 (y - x^2)^2",
36
  }
37
 
38
 
 
68
  }
69
  }
70
 
71
+ const [functionOption, setFunctionOption] = useState<string>("--Custom--");
72
 
73
  function handleFunctionOptionChange(option: string) {
74
  setFunctionOption(option);
 
78
 
79
  function handleModeChange(mode: SettingsUi["mode"]) {
80
  // When changing modes, reset function to Quadratic as some options are mode-specific
81
+ const newFunctionOption = "--Custom--";
82
 
83
  const expr = mode === "Bivariate"
84
  ? BIVARIATE_FUNCTION_OPTIONS[newFunctionOption as keyof typeof BIVARIATE_FUNCTION_OPTIONS]