| from __future__ import annotations |
|
|
| from warnings import warn |
| import inspect |
| from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning |
| from .utils import expand_tuples |
| import itertools as itl |
|
|
|
|
class MDNotImplementedError(NotImplementedError):
    """NotImplementedError variant used by multiple dispatch.

    An implementation registered on a ``Dispatcher`` may raise this to signal
    that it declines the given arguments; ``Dispatcher.__call__`` then tries
    the next matching implementation instead of failing immediately.
    """
|
|
|
|
| |
|
|
def ambiguity_warn(dispatcher, ambiguities):
    """ Emit an ``AmbiguityWarning`` describing ambiguous signatures

    This is the default ``on_ambiguity`` callback used by ``Dispatcher``.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher whose registered signatures are ambiguous; only its
        ``name`` is used, to label the warning
    ambiguities : set
        Pairs of type signatures that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
|
|
|
|
class RaiseNotImplementedError:
    """Callable placeholder that raises ``NotImplementedError`` when invoked.

    ``ambiguity_register_error_ignore_dup`` registers instances of this class
    for ambiguous signatures, so that calls resolving to them fail loudly.
    """

    def __init__(self, dispatcher):
        # Held only so the error message can name the owning dispatcher.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        message = "Ambiguous signature for %s: <%s>" % (
            self.dispatcher.name, str_signature(arg_types))
        raise NotImplementedError(message)
|
|
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """
    Register a raising placeholder for each ambiguity, skipping
    single-type super signatures.

    For every ambiguous pair, the super signature is computed. If it is
    made of one repeated type, nothing is registered. Otherwise an instance
    of ``RaiseNotImplementedError`` is added under that signature, with this
    same function as the ``on_ambiguity`` callback.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    for pair in ambiguities:
        sig = tuple(super_signature(pair))
        # Only signatures that mix distinct types get a placeholder.
        if len(set(sig)) != 1:
            dispatcher.add(
                sig, RaiseNotImplementedError(dispatcher),
                on_ambiguity=ambiguity_register_error_ignore_dup
            )
|
|
| |
|
|
|
|
# Dispatchers that asked to ``reorder`` while ordering was halted
# (``_resolve[0]`` False); ``restart_ordering`` drains this set.
_unresolved_dispatchers: set[Dispatcher] = set()
# One-element list used as a mutable module-level flag: ``_resolve[0]`` is
# True when ``Dispatcher.reorder`` may compute orderings immediately.
_resolve: list[bool] = [True]
|
|
|
|
def halt_ordering():
    """Suspend signature ordering for every dispatcher.

    While halted, ``Dispatcher.reorder`` only records the dispatcher in
    ``_unresolved_dispatchers``. ``restart_ordering`` resumes ordering and
    processes the recorded dispatchers.
    """
    _resolve[0] = False
|
|
|
|
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Re-enable ordering and reorder the dispatchers deferred meanwhile.

    Parameters
    ----------
    on_ambiguity : callable
        Callback passed to each deferred dispatcher's ``reorder``; it is
        invoked with ``(dispatcher, ambiguities)`` when ambiguities exist.
    """
    _resolve[0] = True
    pending = _unresolved_dispatchers
    # Pop one at a time, so the loop stays correct even if a reorder call
    # touches the set.
    while pending:
        pending.pop().reorder(on_ambiguity=on_ambiguity)
|
|
|
|
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        # Registered signature (tuple of types) -> implementation.
        self.funcs = {}
        # Concrete argument-type tuple -> implementation resolved for it;
        # cleared whenever a new implementation is added.
        self._cache = {}
        # Signatures in the order ``dispatch_iter`` tries them, as computed
        # by ``conflict.ordering``.
        self.ordering = []
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` from ``inspect.signature``. """
        return inspect.signature(func).parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ Get annotations of function positional parameters

        Returns a tuple holding the annotation of every positional-only or
        positional-or-keyword parameter. Returns ``None`` when no parameters
        are reported or when any of those parameters is unannotated.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            # Only positional parameters take part in dispatch.
            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects a warning it calls the ``on_ambiguity`` callback
        with a dispatcher/itself, and a set of ambiguous type signature pairs
        as inputs. See ``ambiguity_warn`` for an example.

        Raises
        ------
        TypeError
            If an element of the (expanded) signature is not a type.
        """

        # With no explicit signature, fall back to the function's
        # positional-parameter annotations (when all are present).
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # A tuple inside the signature means "any of these types"; register
        # one entry per concrete combination produced by ``expand_tuples``.
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        # Cached resolutions may now be stale.
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute ``self.ordering`` and report any ambiguities.

        When ordering is halted (``_resolve[0]`` is False), the dispatcher is
        only queued in ``_unresolved_dispatchers`` for a later
        ``restart_ordering``.
        """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if func is None:
                # ``from None``: the cache miss is not part of the error.
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types))) from None
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            # The implementation declined these arguments. Try the remaining
            # matches in resolution order, skipping the first yielded one
            # (assumed to be the implementation that just raised).
            funcs = self.dispatch_iter(*types)
            next(funcs)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass
            raise NotImplementedError("Matching functions for "
                                      "%s: <%s> found, but none completed successfully"
                                      % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal. Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        Returns ``None`` when no registered signature matches.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """

        # Exact match first; otherwise the first compatible signature in
        # resolution order.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """ Yield every implementation whose signature accepts ``types``.

        Signatures are tried in ``self.ordering`` order. A signature matches
        when it has the same length as ``types`` and each argument type is a
        subclass of the corresponding signature type.
        """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                yield self.funcs[signature]

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        # stacklevel=2 attributes the warning to the caller, not this method.
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning, stacklevel=2)

        return self.dispatch(*types)

    def __getstate__(self):
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        # Every slot must be restored: with ``__slots__`` an unset attribute
        # raises AttributeError (e.g. from ``__doc__`` reading ``self.doc``).
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        # States pickled before 'doc' was saved have no such key.
        self.doc = d.get('doc')
        self.ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                header = 'Inputs: <%s>' % str_signature(sig)
                # Underline exactly as wide as the header text.
                s = header + '\n' + '-' * len(header) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n ' + '\n '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        func = self.dispatch(*map(type, args))
        # Mirror ``_source``: without this check the docstring of ``None``
        # would be returned when nothing matches.
        if func is None:
            raise TypeError("No function found")
        return func.__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if func is None:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
|
|
|
|
def source(func):
    """Return ``func``'s source code, prefixed with the file it lives in.

    The result has the form ``'File: <path>\\n\\n<source text>'``.
    """
    header = 'File: %s\n\n' % inspect.getsourcefile(func)
    body = inspect.getsource(func)
    return ''.join((header, body))
|
|
|
|
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Used as a class attribute. The first parameter of each implementation
    (the instance) is excluded from annotation-based signature inference.
    The instance through which the attribute was fetched is passed as that
    first argument on call.

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        # Drop the leading (instance) parameter: it does not take part in
        # dispatch.
        if not hasattr(inspect, "signature"):
            return None
        all_params = inspect.signature(func).parameters.values()
        return itl.islice(all_params, 1, None)

    def __get__(self, instance, owner):
        # NOTE(review): the binding is stored on the shared dispatcher object,
        # so the most recent attribute access wins. A reference obtained
        # earlier through another instance will call with the later instance.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(type(a) for a in args)
        impl = self.dispatch(*arg_types)
        if not impl:
            raise NotImplementedError('Could not find signature for %s: <%s>' %
                                      (self.name, str_signature(arg_types)))
        return impl(self.obj, *args, **kwargs)
|
|
|
|
def str_signature(sig):
    """ Render a type signature as comma-separated type names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
|
|
|
|
def warning_text(name, amb):
    """ Build the message text used for ambiguity warnings

    The message lists each ambiguous signature pair. It then suggests one
    ``@dispatch`` stub per pair, built from that pair's super signature.
    """
    parts = ["\nAmbiguities exist in dispatched function %s\n\n" % (name),
             "The following signatures may result in ambiguous behavior:\n"]
    for pair in amb:
        bracketed = ['[' + str_signature(s) + ']' for s in pair]
        parts.append("\t" + ', '.join(bracketed) + "\n")
    parts.append("\n\nConsider making the following additions:\n\n")
    suggestions = []
    for s in amb:
        # Only the ``def`` suffix is %-formatted, so type names are never
        # interpreted as format directives.
        suggestions.append('@dispatch(' + str_signature(super_signature(s))
                           + ')\ndef %s(...)' % name)
    parts.append('\n\n'.join(suggestions))
    return ''.join(parts)
|
|