Spaces:
Sleeping
Sleeping
| import numpy as np | |
class Oracle:
    """Wrap a vectorised function so it can be queried on batches of inputs.

    An oracle is backed either by a Python callable (``f``) or by a string
    expression (``form``) that is evaluated with the variable columns of the
    input bound to the names in ``variable_names``. Optional per-variable
    range restrictions are enforced on every query.
    """

    # Names that a ``form`` expression may reference in addition to its own
    # variable names. Note that 'pi' is a constant, the rest are functions.
    oracle_d = {'exp': np.exp, 'sqrt': np.sqrt, 'pi': np.pi, 'cos': np.cos, 'sin': np.sin, 'tan': np.tan,
                'tanh': np.tanh, 'ln': np.log, 'arcsin': np.arcsin}

    def __init__(self, nvariables, f=None, form=None, variable_names=None, range_restriction=None, id=None):
        """Create an oracle from exactly one of ``f`` or ``form``.

        nvariables: the number of variables the function takes in.
        f: callable taking X of shape (n, nvariables) and returning f(X) of
            shape (n,).
        form: string definition of the function (a Python expression).
        variable_names: variable names used in ``form``; the i-th name is
            bound to column i of the input.
        range_restriction: dictionary of form {variable_index: (low, high)};
            both bounds are inclusive. Defaults to no restrictions.
        id: optional identifier, used by ``__str__`` when given.

        Raises ValueError if neither or both of ``f`` and ``form`` are given,
        or if ``form`` is given without ``variable_names``.
        """
        self.nvariables = nvariables
        if f is None and form is None:
            raise ValueError("f and form are both none in Oracle initialization. Specify at least one")
        if f is not None and form is not None:
            raise ValueError("f and form are both not none, pick only one")
        if form is not None and variable_names is None:
            raise ValueError("If form is provided then variable_names must also be provided")

        # Always define these attributes so that __str__ (and any other
        # reader) never hits an AttributeError, whichever mode is in use.
        self.form = form
        self.variable_names = variable_names
        self.func = f
        self.id = id

        if form is not None:
            self.use_func = False
            # Per-instance evaluation namespace: the shared helpers plus one
            # placeholder slot per variable, filled in on each call.
            self.d = Oracle.oracle_d.copy()
            for var_name in variable_names:
                self.d[var_name] = None
        else:
            # f is not None
            self.use_func = True

        # A fresh dict per call avoids the shared mutable default pitfall.
        if range_restriction is None:
            range_restriction = {}
        # ranges[i] is the (low, high) pair for variable i, or None.
        self.ranges = [
            range_restriction[i] if i in range_restriction else None
            for i in range(nvariables)
        ]

    def f(self, X):
        """Evaluate the oracle on X.

        X is of shape (n, nvariables).
        Raises ValueError if X fails ``invalid_input``.
        """
        if self.invalid_input(X):
            raise ValueError("Invalid input to Oracle")
        if self.use_func:
            return self.func(X)
        return self.form_f(X)

    def form_f(self, X):
        """Return the function output using ``form``.

        Column i of X is bound to ``variable_names[i]`` before evaluation.
        """
        for i, var in enumerate(self.variable_names):
            self.d[var] = X[:, i]
        # NOTE(security): eval executes arbitrary Python and the namespace
        # does not restrict builtins. Only use with trusted ``form`` strings.
        return eval(self.form, self.d)

    def invalid_input(self, X):
        """Return True if X is not a valid input for this oracle.

        X is invalid if any of the following are true:
            X is not two-dimensional
            X has more or less variables than nvariables
            X has a value in a restricted range variable outside said range
        Otherwise False is returned.
        """
        if X.ndim != 2 or X.shape[1] != self.nvariables:
            return True
        for i, r in enumerate(self.ranges):
            if r is None:
                continue
            low, high = r[0], r[1]
            column = X[:, i]
            # Inclusive bounds; np.all keeps the check vectorised.
            if not (np.all(low <= column) and np.all(column <= high)):
                return True
        return False

    def __str__(self):
        """Return the id if one was given, else the form, else a placeholder."""
        if self.id is not None:
            return str(self.id)
        if self.form:
            return str(self.form)
        return "<Un named Oracle>"

    @staticmethod
    def from_problem(problem):
        """Return an oracle when given an instance of class problem.

        Reads ``n_vars``, ``form``, ``var_names`` and ``eq_id`` from
        ``problem`` and builds a form-based oracle from them.
        """
        return Oracle(nvariables=problem.n_vars, f=None, form=problem.form, variable_names=problem.var_names,
                      id=problem.eq_id)