File size: 3,543 Bytes
f4cade0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# mypy: allow-untyped-defs
from .core import reify, unify # type: ignore[attr-defined]
from .unification_tools import first, groupby # type: ignore[import]
from .utils import _toposort, freeze
from .variable import isvar
class Dispatcher:
    """Select a registered function by unifying call arguments with signatures.

    Functions are stored under frozen signatures.  A call is routed to the
    first signature, in ``self.ordering``, that has the same length as the
    argument tuple and unifies with it.
    """

    def __init__(self, name):
        """Create an empty dispatcher labelled ``name``."""
        self.name = name
        # Frozen signature -> function registered for that signature.
        self.funcs = {}
        # Signatures in the order ``resolve`` tries them; rebuilt by ``add``.
        self.ordering = []

    def add(self, signature, func):
        """Register ``func`` under ``signature`` and recompute the search order."""
        frozen_signature = freeze(signature)
        self.funcs[frozen_signature] = func
        # The module-level ``ordering`` function, not the attribute.
        self.ordering = ordering(self.funcs)

    def __call__(self, *args, **kwargs):
        """Invoke the function matching ``args``, forwarding every argument.

        Only the positional arguments take part in matching.
        """
        handler, _substitution = self.resolve(args)
        return handler(*args, **kwargs)

    def resolve(self, args):
        """Find the function whose signature unifies with ``args``.

        Args:
            args: Tuple of positional call arguments.

        Returns:
            A ``(function, substitution)`` pair for the first signature in
            ``self.ordering`` that has ``len(args)`` entries and unifies with
            the frozen arguments.

        Raises:
            NotImplementedError: If no registered signature matches.
        """
        arity = len(args)
        same_arity = (sig for sig in self.ordering if len(sig) == arity)
        for sig in same_arity:
            substitution = unify(freeze(args), sig)
            # ``unify`` signals failure with ``False``; any other value,
            # including an empty mapping, is a successful match.
            if substitution is False:
                continue
            return self.funcs[sig], substitution
        raise NotImplementedError(
            f"No match found. \nKnown matches: {self.ordering!s}\nInput: {args!s}"
        )

    def register(self, *signature):
        """Return a decorator that registers a function under ``signature``.

        The decorator returns this dispatcher rather than the decorated
        function, so the decorated name ends up bound to the dispatcher.
        """

        def decorator(func):
            self.add(signature, func)
            return self

        return decorator
class VarDispatcher(Dispatcher):
    """Dispatcher that hands matched variables to the function by name.

    Once a signature has been resolved, every entry of the resulting
    substitution is passed to the registered function as a keyword argument
    named after the variable's ``token``.  The positional call arguments and
    any keyword arguments given to the call are not forwarded.

    >>> # xdoctest: +SKIP
    >>> d = VarDispatcher("d")
    >>> x = var("x")
    >>> @d.register("inc", x)
    ... def f(x):
    ...     return x + 1
    >>> @d.register("double", x)
    ... def f(x):
    ...     return x * 2
    >>> d("inc", 10)
    11
    >>> d("double", 10)
    20
    """

    def __call__(self, *args, **kwargs):
        """Resolve ``args`` and call the match with its variable bindings."""
        handler, substitution = self.resolve(args)
        named_bindings = {}
        for variable, value in substitution.items():
            named_bindings[variable.token] = value
        return handler(**named_bindings)
# Default registry used by ``match`` when no ``namespace`` keyword is given;
# it maps a function's ``__name__`` to the dispatcher created for that name.
global_namespace = {}  # type: ignore[var-annotated]
def match(*signature, **kwargs):
    """Return a decorator that registers a function under ``signature``.

    Functions sharing a ``__name__`` are collected in one dispatcher, which is
    created on first use and stored in the namespace under that name.

    Args:
        *signature: Pattern the decorated function is registered for.
        **kwargs: Optional settings.  ``namespace`` is the mapping that holds
            the dispatchers (defaults to ``global_namespace``) and
            ``Dispatcher`` is the callable used to build a new dispatcher
            from a name (defaults to ``Dispatcher``).  Other keys are ignored.

    Returns:
        A decorator that returns the dispatcher, not the decorated function.
    """
    registry = kwargs.get("namespace", global_namespace)
    make_dispatcher = kwargs.get("Dispatcher", Dispatcher)

    def decorator(func):
        label = func.__name__
        # Create the dispatcher lazily, the first time this name is seen.
        if label not in registry:
            registry[label] = make_dispatcher(label)
        target = registry[label]
        target.add(signature, func)
        return target

    return decorator
def supercedes(a, b):
    """Return whether ``a`` is a more specific match than ``b``.

    Args:
        a: Signature (or signature element) being tested.
        b: Signature (or signature element) it is compared against.

    Returns:
        bool: ``True`` when ``b`` is a variable and ``a`` is not, or when
        ``a`` and ``b`` unify and applying the resulting substitution
        (ignoring variable-to-variable bindings) leaves ``a`` unchanged.
        ``False`` in every other case, including when the two do not unify.
    """
    # Anything that is not a variable is more specific than a bare variable.
    if isvar(b) and not isvar(a):
        return True

    substitution = unify(a, b)
    if substitution is False:
        # ``a`` and ``b`` do not unify, so neither supersedes the other.
        return False

    # Discard bindings that merely map one variable onto another; only
    # bindings with a non-variable on at least one side are kept.
    substitution = {
        k: v for k, v in substitution.items() if not (isvar(k) and isvar(v))
    }

    # If the remaining bindings leave ``a`` unchanged, nothing had to be
    # substituted into ``a`` to make it agree with ``b``.
    if reify(a, substitution) == a:
        return True

    # Every remaining outcome is reported as an explicit ``False`` so that
    # all paths return a bool instead of falling through to ``None``.
    return False
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
    """Return whether ``a`` should be checked before ``b``.

    Args:
        a: First signature.
        b: Second signature.
        tie_breaker: Callable applied to both signatures when each supersedes
            the other; the one with the larger result goes first.  Defaults
            to ``hash``.

    Returns:
        ``True`` if ``a`` supersedes ``b`` and either ``b`` does not supersede
        ``a`` or ``a`` wins the tie-break; ``False`` otherwise.
    """
    if not supercedes(a, b):
        return False
    if not supercedes(b, a):
        return True
    # Each supersedes the other, so fall back to the tie-breaker.
    return tie_breaker(a) > tie_breaker(b)
# Taken from multipledispatch
def ordering(signatures):
    """Return the signatures in the order they should be tried, first to last.

    Builds a graph with an edge from ``a`` to ``b`` whenever ``edge(a, b)``
    holds and returns the result of ``_toposort`` on that graph.

    Args:
        signatures: Iterable of signatures; each one is converted to a tuple.
    """
    sigs = [tuple(sig) for sig in signatures]
    precedence_pairs = [
        (before, after) for before in sigs for after in sigs if edge(before, after)
    ]
    # Group the pairs by their first element: signature -> outgoing pairs.
    successors = groupby(first, precedence_pairs)
    # Signatures with no outgoing edge must still appear as graph nodes.
    for sig in sigs:
        if sig not in successors:
            successors[sig] = []
    # Keep only the target of each pair.
    graph = {
        sig: [after for _, after in pairs] for sig, pairs in successors.items()
    }
    return _toposort(graph)
|