| | |
| | |
| |
|
| | import inspect |
| | import threading |
| | import functools |
| |
|
| |
|
| | class _ContextInfo(object): |
| | def __init__(self, cls, allow_default): |
| | self.cls = cls |
| | self.allow_default = allow_default |
| | self._local_stack = threading.local() |
| |
|
| | @property |
| | def _stack(self): |
| | if not hasattr(self._local_stack, 'obj'): |
| | self._local_stack.obj = [] |
| | return self._local_stack.obj |
| |
|
| | def enter(self, value): |
| | self._stack.append(value) |
| |
|
| | def exit(self, value): |
| | assert len(self._stack) > 0, 'Context %s is empty.' % self.cls |
| | assert self._stack.pop() == value |
| |
|
| | def get_active(self, required=True): |
| | if len(self._stack) == 0: |
| | if not required: |
| | return None |
| | assert self.allow_default, ( |
| | 'Context %s is required but none is active.' % self.cls) |
| | self.enter(self.cls()) |
| | return self._stack[-1] |
| |
|
| |
|
| | class _ContextRegistry(object): |
| | def __init__(self): |
| | self._ctxs = {} |
| |
|
| | def get(self, cls): |
| | if cls not in self._ctxs: |
| | assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls) |
| | self._ctxs[cls] = _ContextInfo(cls, allow_default=issubclass(cls, DefaultManaged)) |
| | return self._ctxs[cls] |
| |
|
| |
|
| | _CONTEXT_REGISTRY = _ContextRegistry() |
| |
|
| |
|
| | def _context_registry(): |
| | global _CONTEXT_REGISTRY |
| | return _CONTEXT_REGISTRY |
| |
|
| |
|
| | def _get_managed_classes(obj): |
| | return [ |
| | cls for cls in inspect.getmro(obj.__class__) |
| | if issubclass(cls, Managed) and cls != Managed and cls != DefaultManaged |
| | ] |
| |
|
| |
|
| |
|
| | class Managed(object): |
| | """ |
| | Managed makes the inheritted class a context managed class. |
| | |
| | class Foo(Managed): ... |
| | |
| | with Foo() as f: |
| | assert f == Foo.current() |
| | """ |
| |
|
| | @classmethod |
| | def current(cls, value=None, required=True): |
| | ctx_info = _context_registry().get(cls) |
| | if value is not None: |
| | assert isinstance(value, cls), ( |
| | 'Wrong context type. Expected: %s, got %s.' % (cls, type(value))) |
| | return value |
| | return ctx_info.get_active(required=required) |
| |
|
| | def __enter__(self): |
| | for cls in _get_managed_classes(self): |
| | _context_registry().get(cls).enter(self) |
| | return self |
| |
|
| | def __exit__(self, *args): |
| | for cls in _get_managed_classes(self): |
| | _context_registry().get(cls).exit(self) |
| |
|
| | def __call__(self, func): |
| | @functools.wraps(func) |
| | def wrapper(*args, **kwargs): |
| | with self: |
| | return func(*args, **kwargs) |
| | return wrapper |
| |
|
| |
|
| | class DefaultManaged(Managed): |
| | """ |
| | DefaultManaged is similar to Managed but if there is no parent when |
| | current() is called it makes a new one. |
| | """ |
| | pass |
| |
|