| |
|
| |
|
| |
|
| | import contextlib |
| | import copy |
| | import threading |
| |
|
# Thread-local holder for the active arg scope. Each thread sees its own
# ``current_scope`` attribute (created lazily by get_current_scope), so scopes
# entered on one thread are not visible on another.
_threadlocal_scope = threading.local()
| |
|
| |
|
@contextlib.contextmanager
def arg_scope(single_helper_or_list, **kwargs):
    """Temporarily record default keyword arguments for one or more helpers.

    While the ``with`` block is active, ``kwargs`` are merged into this
    thread's scope entry for each helper, keyed by the helper's ``__name__``.
    On exit -- whether normal or through an exception -- the scope is restored
    to exactly what it was on entry.

    Args:
        single_helper_or_list: a callable, or a list (or tuple) of callables.
        **kwargs: keyword arguments to record for every given helper.

    Raises:
        AssertionError: if a given helper is not callable. All helpers are
            validated before the scope is modified, so a failure leaves the
            scope untouched.
    """
    if isinstance(single_helper_or_list, (list, tuple)):
        helpers = list(single_helper_or_list)
        for helper in helpers:
            assert callable(helper), \
                "arg_scope is only supporting a list of callable helper functions."
    else:
        assert callable(single_helper_or_list), \
            "arg_scope is only supporting single or a list of helper functions."
        helpers = [single_helper_or_list]
    # Resolve every key up front so nothing can fail after we start mutating.
    helper_keys = [helper.__name__ for helper in helpers]

    current_scope = get_current_scope()
    # Only the top-level dict and each helper's kwargs dict are mutated below,
    # so a two-level copy is enough to restore the previous state. Unlike a
    # deepcopy, it keeps the original argument objects (identity preserved
    # after exit) and works for values that cannot be deep-copied.
    old_scope = {key: dict(value) for key, value in current_scope.items()}
    for helper_key in helper_keys:
        current_scope.setdefault(helper_key, {}).update(kwargs)
    try:
        yield
    finally:
        # Restore even if the body of the ``with`` block raised.
        _threadlocal_scope.current_scope = old_scope
| |
|
| |
|
def get_current_scope():
    """Return the calling thread's arg-scope dict, creating it on first use.

    The dict maps a helper's ``__name__`` to the keyword arguments that
    ``arg_scope`` recorded for it. The same dict object is returned on
    every call until something replaces ``_threadlocal_scope.current_scope``.
    """
    try:
        return _threadlocal_scope.current_scope
    except AttributeError:
        # First access on this thread: start with an empty scope.
        _threadlocal_scope.current_scope = {}
        return _threadlocal_scope.current_scope
| |
|