| | """ |
| | Important note on tests in this module - the Aesara printing functions use a |
| | global cache by default, which means that tests using it will modify global |
| | state and thus not be independent from each other. Instead of using the "cache" |
| | keyword argument each time, this module uses the aesara_code_ and |
| | aesara_function_ functions defined below which default to using a new, empty |
| | cache instead. |
| | """ |
| |
|
| | import logging |
| |
|
| | from sympy.external import import_module |
| | from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy |
| |
|
| | from sympy.utilities.exceptions import ignore_warnings |
| |
|
| |
|
# Raise the 'aesara.configdefaults' logger threshold to CRITICAL while the
# optional import is attempted, then set it to WARNING afterwards.
aesaralogger = logging.getLogger('aesara.configdefaults')
aesaralogger.setLevel(logging.CRITICAL)
# NOTE(review): presumably falsy when aesara is not installed (the module
# branches on ``if aesara:``) -- confirm against sympy.external.import_module.
aesara = import_module('aesara')
aesaralogger.setLevel(logging.WARNING)
| |
|
| |
|
if aesara:
    import numpy as np
    aet = aesara.tensor
    from aesara.scalar.basic import ScalarType
    from aesara.graph.basic import Variable
    from aesara.tensor.var import TensorVariable
    from aesara.tensor.elemwise import Elemwise, DimShuffle
    from aesara.tensor.math import Dot

    from sympy.printing.aesaracode import true_divide

    # Aesara scalars named 'x', 'y', 'z' with dtype 'floatX'.
    xt, yt, zt = (aet.scalar(label, 'floatX') for label in 'xyz')
    # Aesara tensors named 'X', 'Y', 'Z' with broadcastable pattern
    # (False, False).
    Xt, Yt, Zt = (
        aet.tensor('floatX', (False, False), name=label) for label in 'XYZ'
    )
else:
    # NOTE(review): presumably read by the test runner to skip this module
    # when aesara is unavailable -- confirm.
    disabled = True
| |
|
| | import sympy as sy |
| | from sympy.core.singleton import S |
| | from sympy.abc import x, y, z, t |
| | from sympy.printing.aesaracode import (aesara_code, dim_handling, |
| | aesara_function) |
| |
|
| |
|
| | |
| | |
# 4x4 matrix symbols named 'X', 'Y' and 'Z', shared by the tests in this module.
X, Y, Z = (sy.MatrixSymbol(label, 4, 4) for label in 'XYZ')

# The undefined function 'f' applied to the symbol t.
f_t = sy.Function('f')(t)
| |
|
| |
|
def aesara_code_(expr, **kwargs):
    """
    Call aesara_code() with a fresh, empty cache unless one is supplied.

    All keyword arguments are forwarded to aesara_code(). The call is made
    inside warns_deprecated_sympy().
    """
    if 'cache' not in kwargs:
        kwargs['cache'] = {}
    with warns_deprecated_sympy():
        printed = aesara_code(expr, **kwargs)
    return printed
| |
|
def aesara_function_(inputs, outputs, **kwargs):
    """
    Call aesara_function() with a fresh, empty cache unless one is supplied.

    All keyword arguments are forwarded to aesara_function(). The call is
    made inside warns_deprecated_sympy().
    """
    if 'cache' not in kwargs:
        kwargs['cache'] = {}
    with warns_deprecated_sympy():
        compiled = aesara_function(inputs, outputs, **kwargs)
    return compiled
| |
|
| |
|
def fgraph_of(*exprs):
    """ Build an Aesara FunctionGraph from SymPy expressions.

    Parameters
    ==========
    exprs
        SymPy expressions; each is printed with aesara_code_().

    Returns
    =======
    aesara.graph.fg.FunctionGraph
    """
    graph_basic = aesara.graph.basic
    printed = [aesara_code_(expr) for expr in exprs]
    graph_ins = list(graph_basic.graph_inputs(printed))
    # The graph is built from clones of the printed inputs and outputs.
    cloned_ins, cloned_outs = graph_basic.clone(graph_ins, printed)
    return aesara.graph.fg.FunctionGraph(cloned_ins, cloned_outs)
| |
|
| |
|
def aesara_simplify(fgraph):
    """ Simplify an Aesara computation.

    The default compilation mode, with "fusion" excluded, is used to rewrite
    a clone of the given graph; the argument itself is not rewritten.

    Parameters
    ==========
    fgraph : aesara.graph.fg.FunctionGraph

    Returns
    =======
    aesara.graph.fg.FunctionGraph
    """
    default_mode = aesara.compile.get_default_mode()
    rewrite_mode = default_mode.excluding("fusion")
    simplified = fgraph.clone()
    rewrite_mode.optimizer.rewrite(simplified)
    return simplified
| |
|
| |
|
def theq(a, b):
    """ Test two Aesara objects for equality.

    Also accepts numeric types and lists/tuples of supported types.

    Numbers are compared with ``==``. Lists/tuples must have the same type
    and length and are compared element by element with theq(). Anything
    else is compared through its debugprint() string representation.

    Note - debugprint() has a bug where it will accept numeric types but does
    not respect the "file" argument in this case and instead prints the number
    to stdout and returns an empty string. This can lead to tests passing where
    they should fail because any two numbers will always compare as equal. To
    prevent this we treat numbers as a separate case.

    Raises
    ======
    TypeError
        If debugprint() returns an empty string for either argument.
    """
    numeric_types = (int, float, np.number)
    a_is_num = isinstance(a, numeric_types)
    b_is_num = isinstance(b, numeric_types)

    # Compare numeric types using regular equality
    if a_is_num or b_is_num:
        if not (a_is_num and b_is_num):
            return False

        return a == b

    # Compare sequences element-wise
    a_is_seq = isinstance(a, (tuple, list))
    b_is_seq = isinstance(b, (tuple, list))

    if a_is_seq or b_is_seq:
        if not (a_is_seq and b_is_seq) or type(a) is not type(b):
            return False

        # zip() stops at the shorter sequence, so the lengths have to be
        # checked explicitly.
        if len(a) != len(b):
            return False

        return all(theq(a_item, b_item) for a_item, b_item in zip(a, b))

    # Otherwise, assume debugprint() can handle it
    astr = aesara.printing.debugprint(a, file='str')
    bstr = aesara.printing.debugprint(b, file='str')

    # Check for bug mentioned above
    for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
        if argstr == '':
            raise TypeError(
                'aesara.printing.debugprint(%s) returned empty string '
                '(%s is instance of %r)'
                % (argname, argname, type(argval))
            )

    return astr == bstr
| |
|
| |
|
def test_example_symbols():
    """
    Verify that each example SymPy symbol of this module prints to its Aesara
    counterpart; many of the other tests rely on this.
    """
    counterparts = [(xt, x), (yt, y), (zt, z), (Xt, X), (Yt, Y), (Zt, Z)]
    for aesara_var, sympy_symbol in counterparts:
        assert theq(aesara_var, aesara_code_(sympy_symbol))
| |
|
| |
|
def test_Symbol():
    """ Check that a Symbol prints to an Aesara variable of the same name. """
    default_var = aesara_code_(x)
    assert isinstance(default_var, Variable)
    assert default_var.broadcastable == ()
    assert default_var.name == x.name

    # An explicit broadcastable pattern is carried over to the variable.
    pattern = (False,)
    patterned_var = aesara_code_(x, broadcastables={x: pattern})
    assert patterned_var.broadcastable == (False,)
    assert patterned_var.name == x.name
| |
|
def test_MatrixSymbol():
    """ Check that a MatrixSymbol prints to a non-broadcastable 2D tensor. """
    printed = aesara_code_(X)
    assert isinstance(printed, TensorVariable)
    assert printed.broadcastable == (False, False)
| |
|
@SKIP
def test_MatrixSymbol_wrong_dims():
    """ A MatrixSymbol with an invalid broadcastable should raise ValueError. """
    invalid_patterns = (
        (),
        (False,),
        (True,),
        (True, False),
        (False, True,),
        (True, True),
    )
    for pattern in invalid_patterns:
        with raises(ValueError):
            aesara_code_(X, broadcastables={X: pattern})
| |
|
def test_AppliedUndef():
    """ Check printing of an AppliedUndef instance (behaves much like Symbol). """
    printed = aesara_code_(f_t)
    assert isinstance(printed, TensorVariable)
    assert printed.broadcastable == ()
    assert printed.name == 'f_t'
| |
|
| |
|
def test_add():
    """ The sum of two symbols should print to a node whose op is add. """
    printed = aesara_code_(x + y)
    assert printed.owner.op == aesara.tensor.add
| |
|
def test_trig():
    """ sin and tan should print to their Aesara tensor equivalents. """
    printed_sin = aesara_code_(sy.sin(x))
    assert theq(printed_sin, aet.sin(xt))

    printed_tan = aesara_code_(sy.tan(x))
    assert theq(printed_tan, aet.tan(xt))
| |
|
def test_many():
    """ Check printing of a compound expression involving several symbols. """
    sympy_expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
    printed = aesara_code_(sympy_expr)
    reference = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt)
    assert theq(printed, reference)
| |
|
| |
|
def test_dtype():
    """ Check that the dtypes argument selects the data type of the result. """
    explicit_dtypes = ('float32', 'float64', 'int8', 'int16', 'int32', 'int64')
    for dtype in explicit_dtypes:
        printed = aesara_code_(x, dtypes={x: dtype})
        assert printed.type.dtype == dtype

    # "floatX" must resolve to one of the two float types
    floatx_var = aesara_code_(x, dtypes={x: 'floatX'})
    assert floatx_var.type.dtype in ('float32', 'float64')

    # Type promotion
    with_constant = aesara_code_(x + 1, dtypes={x: 'float32'})
    assert with_constant.type.dtype == 'float32'

    mixed = aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'})
    assert mixed.type.dtype == 'float64'
| |
|
| |
|
def test_broadcastables():
    """ Check the "broadcastables" argument for symbol-like objects. """
    patterns = [(), (False,), (True,), (False, False), (True, False)]

    # A Symbol and an applied undefined function are both exercised.
    for symbol_like in (x, f_t):
        for pattern in patterns:
            printed = aesara_code_(
                symbol_like, broadcastables={symbol_like: pattern})
            assert printed.broadcastable == pattern
| |
|
| | |
| |
|
def test_broadcasting():
    """ Check "broadcastable" of the result of an element-wise binary op. """
    total = x + y

    # Each row: (pattern of x, pattern of y, expected pattern of x + y)
    scenarios = (
        ((), (), ()),
        ((False,), (False,), (False,)),
        ((True,), (False,), (False,)),
        ((False, True), (False, False), (False, False)),
        ((True, False), (False, False), (False, False)),
    )

    for x_pattern, y_pattern, expected_pattern in scenarios:
        printed = aesara_code_(
            total, broadcastables={x: x_pattern, y: y_pattern})
        assert printed.broadcastable == expected_pattern
| |
|
| |
|
def test_MatMul():
    """ A matrix product should print to chained Dot operations. """
    printed = aesara_code_(X*Y*Z)
    assert isinstance(printed.owner.op, Dot)

    reference = Xt.dot(Yt).dot(Zt)
    assert theq(printed, reference)
| |
|
def test_Transpose():
    """ A matrix transpose should print to a DimShuffle op. """
    printed = aesara_code_(X.T)
    assert isinstance(printed.owner.op, DimShuffle)
| |
|
def test_MatAdd():
    """ A matrix sum should print to an Elemwise op. """
    printed = aesara_code_(X+Y+Z)
    assert isinstance(printed.owner.op, Elemwise)
| |
|
| |
|
def test_Rationals():
    """ Rational numbers should print to a true_divide of their parts. """
    two_thirds = aesara_code_(sy.Integer(2) / 3)
    assert theq(two_thirds, true_divide(2, 3))

    one_half = aesara_code_(S.Half)
    assert theq(one_half, true_divide(1, 2))
| |
|
def test_Integers():
    """ A SymPy Integer should print to something equal to the plain number. """
    printed = aesara_code_(sy.Integer(3))
    assert printed == 3
| |
|
def test_factorial():
    """ Printing factorial(n) should produce a truthy result. """
    n = sy.Symbol('n')
    printed = aesara_code_(sy.factorial(n))
    assert printed
| |
|
def test_Derivative():
    """
    An unevaluated Derivative should print to the same simplified graph as
    aesara.grad() of the equivalent Aesara expression.
    """
    def simp(expr):
        # Build a function graph from expr and simplify it.
        return aesara_simplify(fgraph_of(expr))

    with ignore_warnings(UserWarning):
        unevaluated = sy.Derivative(sy.sin(x), x, evaluate=False)
        printed = simp(aesara_code_(unevaluated))
        reference = simp(aesara.grad(aet.sin(xt), xt))
        assert theq(printed, reference)
| |
|
| |
|
def test_aesara_function_simple():
    """ Check aesara_function() with a single output. """
    add = aesara_function_([x, y], [x+y])
    result = add(2, 3)
    assert result == 5
| |
|
def test_aesara_function_multi():
    """ Check aesara_function() with several outputs. """
    add_and_sub = aesara_function_([x, y], [x+y, x-y])
    total, difference = add_and_sub(2, 3)
    assert total == 5
    assert difference == -1
| |
|
def test_aesara_function_numpy():
    """ Compare aesara_function() results against a NumPy computation. """
    float_dtypes = {x: 'float64', y: 'float64'}

    # Plain Python lists as input
    f = aesara_function_([x, y], [x+y], dim=1, dtypes=float_dtypes)
    error = f([1, 2], [3, 4]) - np.asarray([4, 6])
    assert np.linalg.norm(error) < 1e-9

    # NumPy arrays as input
    f = aesara_function_([x, y], [x+y], dtypes=float_dtypes, dim=1)
    xx = np.arange(3).astype('float64')
    yy = 2*np.arange(3).astype('float64')
    error = f(xx, yy) - 3*np.arange(3)
    assert np.linalg.norm(error) < 1e-9
| |
|
| |
|
def test_aesara_function_matrix():
    """ Check aesara_function() with Matrix outputs built from scalars. """
    m = sy.Matrix([[x, y], [z, x + y + z]])
    expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
    symbols = [x, y, z]
    values = (1.0, 2.0, 3.0)

    # Single matrix output, without and with the scalar option
    f = aesara_function_(symbols, [m])
    np.testing.assert_allclose(f(*values), expected)
    f = aesara_function_(symbols, [m], scalar=True)
    np.testing.assert_allclose(f(*values), expected)

    # Two matrix outputs come back as a list
    f = aesara_function_(symbols, [m, m])
    assert isinstance(f(*values), list)
    for index in (0, 1):
        np.testing.assert_allclose(f(*values)[index], expected)
| |
|
def test_dim_handling():
    """ Check the broadcastable patterns computed by dim_handling(). """
    # A single "dim" applies to every input
    from_dim = dim_handling([x], dim=2)
    assert from_dim == {x: (False, False)}

    # Per-symbol "dims"
    from_dims = dim_handling([x, y], dims={x: 1, y: 2})
    assert from_dims == {x: (False, True), y: (False, False)}

    # Explicit "broadcastables" are returned as given
    from_bc = dim_handling([x], broadcastables={x: (False,)})
    assert from_bc == {x: (False,)}
| |
|
def test_aesara_function_kwargs():
    """
    Check that extra keyword arguments given to aesara_function() are passed
    on to aesara.function().
    """
    import numpy as np
    float_dtypes = {x: 'float64', y: 'float64', z: 'float64'}

    # z is an unused input, accepted because of on_unused_input='ignore'
    f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
                         dtypes=float_dtypes)
    error = f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])
    assert np.linalg.norm(error) < 1e-9

    f = aesara_function_([x, y, z], [x+y], dtypes=float_dtypes,
                         dim=1, on_unused_input='ignore')
    xx = np.arange(3).astype('float64')
    yy = 2*np.arange(3).astype('float64')
    zz = 2*np.arange(3).astype('float64')
    error = f(xx, yy, zz) - 3*np.arange(3)
    assert np.linalg.norm(error) < 1e-9
| |
|
def test_aesara_function_scalar():
    """ Check the "scalar" argument of aesara_function(). """
    from aesara.compile.function.types import Function

    # Each row: (inputs, outputs, dims argument, expected ndim per output)
    scenarios = [
        ([x, y], [x + y], None, [0]),
        ([X, Y], [X + Y], None, [2]),
        ([x, y], [x + y], {x: 0, y: 1}, [1]),
        ([x, y], [x + y, x - y], None, [0, 0]),
        ([x, y, X, Y], [x + y, X + Y], None, [0, 2]),
    ]

    # Run every scenario with and without the scalar setting
    for inputs, outputs, in_dims, out_dims in scenarios:
        for scalar in (False, True):

            f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar)

            # The aesara_function attribute must be set either way
            assert isinstance(f.aesara_function, Function)

            # Build inputs of ones: length 1 on broadcastable axes, 5 otherwise
            in_values = []
            for storage in f.aesara_function.input_storage:
                shape = [1 if bc else 5 for bc in storage.type.broadcastable]
                in_values.append(np.ones(shape))

            out_values = f(*in_values)
            if not isinstance(out_values, list):
                out_values = [out_values]

            # Check output types and shapes
            assert len(out_dims) == len(out_values)
            for ndim, value in zip(out_dims, out_values):

                if ndim != 0 or not scalar:
                    # Stays an array of the expected rank
                    assert isinstance(value, np.ndarray)
                    assert value.ndim == ndim

                else:
                    # Zero-dimensional output with scalar=True: a NumPy scalar
                    assert isinstance(value, np.number)
| |
|
def test_aesara_function_bad_kwarg():
    """
    An unknown keyword argument given to aesara_function() should raise an
    exception.
    """
    with raises(Exception):
        aesara_function_([x], [x+1], foobar=3)
| |
|
| |
|
def test_slice():
    """ Check printing of slice objects, with numeric and symbolic parts. """
    assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3)

    def theq_slice(s1, s2):
        """ Compare two slices attribute by attribute using theq(). """
        for attr in ['start', 'stop', 'step']:
            a1 = getattr(s1, attr)
            a2 = getattr(s2, attr)
            if a1 is None or a2 is None:
                # None only matches None. Testing "or" here instead of "and"
                # could never fail, so a None/non-None mismatch would
                # wrongly count as equal.
                if not (a1 is None and a2 is None):
                    return False
            elif not theq(a1, a2):
                return False
        return True

    dtypes = {x: 'int32', y: 'int32'}
    assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
    assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
| |
|
def test_MatrixSlice():
    """ Check printing of a slice taken from a MatrixSymbol. """
    cache = {}

    n = sy.Symbol('n', integer=True)
    matrix = sy.MatrixSymbol('X', n, n)

    sliced = matrix[1:2:3, 4:5:6]
    sliced_t = aesara_code_(sliced, cache=cache)

    int_type = ScalarType('int64')
    expected_idx = (
        slice(int_type, int_type, int_type),
        slice(int_type, int_type, int_type),
    )
    assert tuple(sliced_t.owner.op.idx_list) == expected_idx
    # The first input is compared with the matrix printed through the same
    # cache.
    assert sliced_t.owner.inputs[0] == aesara_code_(matrix, cache=cache)
    # Inputs 1..6 are expected to hold the slice bounds 1..6 in order.
    for position in range(1, 7):
        assert sliced_t.owner.inputs[position].data == position

    k = sy.Symbol('k')
    aesara_code_(k, dtypes={k: 'int32'})
    start, stop, step = 4, k, 2
    sliced = matrix[start:stop:step]
    # Only checks that printing a slice with a symbolic bound does not raise.
    sliced_t = aesara_code_(sliced, dtypes={n: 'int32', k: 'int32'})
| | |
| |
|
def test_BlockMatrix():
    """ A 2x2 BlockMatrix should print to nested joins of its blocks. """
    n = sy.Symbol('n', integer=True)
    A, B, C, D = (sy.MatrixSymbol(name, n, n) for name in 'ABCD')
    At, Bt, Ct, Dt = (aesara_code_(block) for block in (A, B, C, D))
    printed = aesara_code_(sy.BlockMatrix([[A, B], [C, D]]))

    # Either joining order is accepted.
    rows_first = aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt))
    columns_first = aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))
    assert any(theq(printed, candidate)
               for candidate in (rows_first, columns_first))
| |
|
@SKIP
def test_BlockMatrix_Inverse_execution():
    """
    Compare the numeric result of B.I*A with that of its block-cut,
    block-collapsed form.
    """
    k, n = 2, 4
    dtype = 'float32'
    A = sy.MatrixSymbol('A', n, k)
    B = sy.MatrixSymbol('B', n, n)
    inputs = A, B
    output = B.I*A

    # Cut each input into halves along both axes
    cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
                B: [(n//2, n//2), (n//2, n//2)]}
    cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
    cutoutput = output.subs(dict(zip(inputs, cutinputs)))

    dtypes = dict(zip(inputs, [dtype]*len(inputs)))
    f = aesara_function_(inputs, [output], dtypes=dtypes, cache={})
    fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)],
                                dtypes=dtypes, cache={})

    # Fixed inputs: A counts upwards, B is the identity plus 1e-5 everywhere.
    # (A random initialisation used to be assigned here and was overwritten
    # straight away without being used, so it has been dropped.)
    ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
               np.eye(n).astype(dtype)]
    ninputs[1] += np.ones(B.shape)*1e-5

    assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
| |
|
def test_DenseMatrix():
    """ Mutable and immutable dense matrices should print to a Join op. """
    from aesara.tensor.basic import Join

    theta = sy.Symbol('theta')
    entries = [[sy.cos(theta), -sy.sin(theta)],
               [sy.sin(theta), sy.cos(theta)]]
    for matrix_cls in (sy.Matrix, sy.ImmutableMatrix):
        printed = aesara_code_(matrix_cls(entries))
        assert isinstance(printed, TensorVariable)
        assert isinstance(printed.owner.op, Join)
| |
|
| |
|
def test_cache_basic():
    """ Check that lone symbol-like objects are cached when printed. """

    # Each pair holds two objects that should be equivalent for caching
    equivalent_pairs = (
        (x, sy.Symbol('x')),
        (X, sy.MatrixSymbol('X', *X.shape)),
        (f_t, sy.Function('f')(sy.Symbol('t'))),
    )

    for original, lookalike in equivalent_pairs:
        cache = {}
        first_print = aesara_code_(original, cache=cache)

        # Hit: same instance, same cache
        assert aesara_code_(original, cache=cache) is first_print

        # Miss: same instance, new cache
        assert aesara_code_(original, cache={}) is not first_print

        # Hit: different but equivalent instance, same cache
        assert aesara_code_(lookalike, cache=cache) is first_print
| |
|
def test_global_cache():
    """ Test use of the global cache. """
    from sympy.printing.aesaracode import global_cache

    backup = dict(global_cache)
    try:
        # Temporarily empty the global cache
        global_cache.clear()

        for s in [x, X, f_t]:
            with warns_deprecated_sympy():
                # No "cache" argument here, so the global cache is exercised
                st = aesara_code(s)
                assert aesara_code(s) is st

    finally:
        # Restore the global cache. It is cleared first because update()
        # alone would leave the entries added by this test in place.
        global_cache.clear()
        global_cache.update(backup)
| |
|
def test_cache_types_distinct():
    """
    Check that the cache keeps symbol-like objects of different types
    (Symbol, MatrixSymbol, AppliedUndef) apart even when they share a name.
    """
    same_name_objects = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]

    cache = {}
    printed = {}

    for obj in same_name_objects:
        obj_t = aesara_code_(obj, cache=cache)
        assert obj_t not in printed.values()
        printed[obj] = obj_t

    # All printed objects are distinct
    distinct_ids = {id(obj_t) for obj_t in printed.values()}
    assert len(distinct_ids) == len(same_name_objects)

    # Printing again through the same cache retrieves the same objects
    for obj, obj_t in printed.items():
        with warns_deprecated_sympy():
            assert aesara_code(obj, cache=cache) is obj_t
| |
|
def test_symbols_are_created_once():
    """
    Check that a symbol occurring more than once in an expression is cached
    and reused.
    """
    doubled = sy.Add(x, x, evaluate=False)
    printed = aesara_code_(doubled)

    # Same variable on both sides of the sum...
    assert theq(printed, xt + xt)
    # ...and not two separately created variables.
    assert not theq(printed, xt + aesara_code_(x))
| |
|
def test_cache_complex():
    """
    Check caching on a larger expression in which several symbols occur
    several times.
    """
    expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
    symbol_names = {symbol.name for symbol in expr.free_symbols}
    expr_t = aesara_code_(expr)

    # Walk the variables that the printed expression depends on
    graph_basic = aesara.graph.basic
    seen = set()
    for var in graph_basic.ancestors([expr_t]):
        # Only variables without an owner that are not constants are checked
        if var.owner is not None or isinstance(var, graph_basic.Constant):
            continue

        # Must correspond to a symbol and appear only once
        assert var.name in symbol_names
        assert var.name not in seen
        seen.add(var.name)

    # Every symbol was present
    assert seen == symbol_names
| |
|
| |
|
def test_Piecewise():
    """ Piecewise expressions should print to nested switch operations. """
    # Three branches with a final catch-all
    piecewise = sy.Piecewise((0, x<0), (x, x<2), (1, True))
    printed = aesara_code_(piecewise)
    assert printed.owner.op == aet.switch

    reference = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1))
    assert theq(printed, reference)

    # No catch-all branch: the fallback value is nan
    piecewise = sy.Piecewise((x, x < 0))
    printed = aesara_code_(piecewise)
    reference = aet.switch(xt < 0, xt, np.nan)
    assert theq(printed, reference)

    # Compound conditions built with And / Or
    piecewise = sy.Piecewise(
        (0, sy.And(x>0, x<2)),
        (x, sy.Or(x>2, x<0)),
    )
    printed = aesara_code_(piecewise)
    reference = aet.switch(
        aet.and_(xt>0,xt<2), 0,
        aet.switch(aet.or_(xt>2, xt<0), xt, np.nan),
    )
    assert theq(printed, reference)
| |
|
| |
|
def test_Relationals():
    """ Relational expressions should print to the matching Aesara ops. """
    equality = aesara_code_(sy.Eq(x, y))
    assert theq(equality, aet.eq(xt, yt))

    # Ordering relations
    greater = aesara_code_(x > y)
    assert theq(greater, xt > yt)
    less = aesara_code_(x < y)
    assert theq(less, xt < yt)
    greater_equal = aesara_code_(x >= y)
    assert theq(greater_equal, xt >= yt)
    less_equal = aesara_code_(x <= y)
    assert theq(less_equal, xt <= yt)
| |
|
| |
|
def test_complexfunctions():
    """ Check printing of conjugate and of a complex coefficient. """
    from sympy.functions.elementary.complexes import conjugate
    from aesara.tensor import as_tensor_variable as atv
    from aesara.tensor import complex as cplx

    dtypes = {x: 'complex128', y: 'complex128'}
    with warns_deprecated_sympy():
        x_complex = aesara_code(x, dtypes=dtypes)
        y_complex = aesara_code(y, dtypes=dtypes)

    with warns_deprecated_sympy():
        printed = aesara_code(y*conjugate(x), dtypes=dtypes)
        assert theq(printed, y_complex*(x_complex.conj()))

        printed = aesara_code((1+2j)*x)
        assert theq(printed, x_complex*(atv(1.0)+atv(2.0)*cplx(0,1)))
| |
|
| |
|
def test_constantfunctions():
    """ A function with no inputs and a constant output returns that constant. """
    with warns_deprecated_sympy():
        constant_fn = aesara_function([], [1+1j])
    assert constant_fn() == 1+1j
| |
|