| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import |
| | from __future__ import print_function |
| | from __future__ import division |
| |
|
| | import inspect |
| | import importlib |
| | import re |
| |
|
# `docstring_parser` is an optional dependency. Any failure while importing it
# (not only ImportError) deliberately falls back to a stub so that this module
# stays importable without it.
try:
    from docstring_parser import parse as doc_parse
except Exception:

    def doc_parse(*args, **kwargs):
        """Stand-in used when `docstring_parser` cannot be imported.

        Accepts any positional or keyword arguments and always returns
        None, which `extract_schema` treats as "no parameter docs".
        """
        return None
| |
|
| |
|
# `typeguard` is an optional dependency. Any failure while importing it
# deliberately falls back to a no-op stub, which means schema type
# validation is silently skipped when typeguard is unavailable.
try:
    from typeguard import check_type
except Exception:

    def check_type(*args, **kwargs):
        """Stand-in used when `typeguard` cannot be imported.

        Accepts any positional or keyword arguments, performs no checking
        and always returns None (i.e. every value is accepted).
        """
        return None
| |
|
| |
|
| | __all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema'] |
| |
|
| |
|
class SchemaValue(object):
    """Description of a single named entry in a schema.

    Args:
        name: Name of the entry.
        doc: Free-form description of the entry (empty string by default).
        type: Expected type of the entry's value, or None if unspecified.

    The ``default`` attribute only exists after `set_default` has been
    called; use `has_default` to test for it.
    """

    def __init__(self, name, doc='', type=None):
        super(SchemaValue, self).__init__()
        self.name, self.doc, self.type = name, doc, type

    def set_default(self, value):
        """Store ``value`` as this entry's default."""
        setattr(self, 'default', value)

    def has_default(self):
        """Return True if this entry has a ``default`` attribute."""
        try:
            self.default
        except AttributeError:
            return False
        return True
| |
|
| |
|
class SchemaDict(dict):
    """Dictionary of config values paired with a schema describing them.

    The mapping itself holds the values that were explicitly set. The
    ``schema`` attribute maps key names to `SchemaValue` objects that may
    carry a default and an expected type.

    Attributes:
        schema (dict): key name -> `SchemaValue`.
        strict (bool): when True, `validate` rejects keys absent from the
            schema.
        doc (str): description of the object this schema belongs to.
        name (str): name used in validation error messages; empty until
            assigned by the caller.
    """

    def __init__(self, **kwargs):
        super(SchemaDict, self).__init__()
        self.schema = {}
        self.strict = False
        self.doc = ""
        # `validate` and `find_mismatch_keys` format `self.name` into their
        # messages, so it must exist even before a caller assigns it.
        self.name = ""
        self.update(kwargs)

    def __setitem__(self, key, value):
        # Assigning a plain dict over an existing nested SchemaDict merges
        # into it instead of replacing it, so the nested schema is kept.
        if (isinstance(value, dict) and key in self and
                isinstance(self[key], SchemaDict)):
            self[key].update(value)
        else:
            super(SchemaDict, self).__setitem__(key, value)

    def __missing__(self, key):
        # Lookup fallback order: schema default, then the SchemaValue
        # itself (key is declared but has no default), then KeyError.
        if self.has_default(key):
            return self.schema[key].default
        if key in self.schema:
            return self.schema[key]
        raise KeyError(key)

    def copy(self):
        """Return a shallow copy that shares attribute objects.

        Instance attributes (including the ``schema`` dict) are copied by
        reference, so schema changes are visible through both objects.
        """
        newone = SchemaDict()
        newone.__dict__.update(self.__dict__)
        newone.update(self)
        return newone

    def set_schema(self, key, value):
        """Register ``value`` (a `SchemaValue`) as the schema for ``key``."""
        assert isinstance(value, SchemaValue)
        self.schema[key] = value

    def set_strict(self, strict):
        """Enable or disable rejection of keys not present in the schema."""
        self.strict = strict

    def has_default(self, key):
        """Return True if ``key`` is in the schema and has a default."""
        return key in self.schema and self.schema[key].has_default()

    def is_default(self, key):
        """Return True if ``key`` is considered to hold its default value.

        Keys without a schema default are never default. A current value
        that exposes ``__dict__`` (an arbitrary object instance) is always
        reported as default; otherwise the key must be unset or compare
        equal to the schema default.
        """
        if not self.has_default(key):
            return False
        if hasattr(self[key], '__dict__'):
            return True
        return key not in self or self[key] == self.schema[key].default

    def find_default_keys(self):
        """Return the keys for which `is_default` holds.

        Stored keys come first, then schema keys; each key is listed once
        even when it is both stored and declared in the schema.
        """
        seen = set()
        default_keys = []
        for k in list(self.keys()) + list(self.schema.keys()):
            if k in seen:
                continue
            seen.add(k)
            if self.is_default(k):
                default_keys.append(k)
        return default_keys

    def mandatory(self):
        """Return True if at least one schema key has no default."""
        return any(not self.has_default(k) for k in self.schema)

    def find_missing_keys(self):
        """Return keys that still need a value.

        That is: schema keys that are neither set nor defaulted, followed
        by set keys whose value is one of the placeholder strings
        ``'<missing>'`` or ``'<value>'``.
        """
        missing = [
            k for k in self.schema
            if k not in self and not self.has_default(k)
        ]
        placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
        return missing + placeholders

    def find_extra_keys(self):
        """Return set keys that are not declared in the schema."""
        return list(set(self.keys()) - set(self.schema.keys()))

    def find_mismatch_keys(self):
        """Return schema key names whose value fails the type check.

        Entries without a declared type are skipped. Any exception raised
        by ``check_type`` counts as a mismatch.
        """
        mismatch_keys = []
        for arg in self.schema.values():
            if arg.type is None:
                continue
            try:
                # NOTE(review): assumes the (argname, value, expected_type)
                # calling convention of `check_type` -- confirm against the
                # installed typeguard version.
                check_type("{}.{}".format(self.name, arg.name),
                           self[arg.name], arg.type)
            except Exception:
                mismatch_keys.append(arg.name)
        return mismatch_keys

    def validate(self):
        """Check the stored values against the schema.

        Raises:
            ValueError: if required keys are missing, or if unknown keys
                are present while ``strict`` is True.
            TypeError: if any value fails its declared type check.
        """
        missing_keys = self.find_missing_keys()
        if missing_keys:
            raise ValueError("Missing param for class<{}>: {}".format(
                self.name, ", ".join(missing_keys)))
        extra_keys = self.find_extra_keys()
        if extra_keys and self.strict:
            raise ValueError("Extraneous param for class<{}>: {}".format(
                self.name, ", ".join(extra_keys)))
        mismatch_keys = self.find_mismatch_keys()
        if mismatch_keys:
            raise TypeError("Wrong param type for class<{}>: {}".format(
                self.name, ", ".join(mismatch_keys)))
| |
|
| |
|
class SharedConfig(object):
    """
    Marker object for arguments listed in a class's `__shared__` annotation.

    The value is meant to be resolved in this order:

    - a value for `key` set on the module itself in the config file wins
    - otherwise a value for `key` present elsewhere in the config file is
      used
    - otherwise `default_value` is the fallback

    Args:
        key: name of the config entry to inject (config[key])
        default_value: value to fall back on when the config has no `key`
    """

    def __init__(self, key, default_value=None):
        super(SharedConfig, self).__init__()
        self.default_value = default_value
        self.key = key
| |
|
| |
|
def extract_schema(cls):
    """
    Extract schema from a given class

    The schema is built from the signature of ``cls.__init__``: every
    argument except ``self`` becomes a `SchemaValue`, with its description
    taken from the parsed class docstring (falling back to the argument
    name), its type taken from the annotations, and its default taken from
    the signature. Arguments named in ``cls.__shared__`` get a
    `SharedConfig` default; arguments named in ``cls.__inject__`` get no
    type.

    Args:
        cls (type): Class from which to extract.

    Returns:
        schema (SchemaDict): Extracted schema.

    Raises:
        AssertionError: if an argument listed in ``__shared__`` has no
            default value in the signature.
    """
    ctor = cls.__init__

    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(ctor)
        annotations = argspec.annotations
        has_kwargs = argspec.varkw is not None
    else:
        # Interpreters without `getfullargspec` (Python 2) only provide
        # `getargspec`, whose result names the **kwargs slot `keywords`
        # and carries no annotations.
        argspec = inspect.getargspec(ctor)
        annotations = getattr(ctor, '__annotations__', {})
        has_kwargs = argspec.keywords is not None

    names = [arg for arg in argspec.args if arg != 'self']
    defaults = argspec.defaults or ()
    # Defaults always belong to the trailing arguments.
    num_required = len(names) - len(defaults)

    docs = cls.__doc__
    if docs is None and getattr(cls, '__category__', None) == 'op':
        docs = cls.__call__.__doc__
    try:
        docstring = doc_parse(docs)
    except Exception:
        # Parameter descriptions are best-effort; an unparsable docstring
        # simply yields no comments.
        docstring = None

    comments = {}
    if docstring is not None:
        for p in docstring.params:
            # Keep only the leading identifier of the documented name.
            match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
            if match_obj is not None:
                comments[match_obj.group(1)] = p.description

    schema = SchemaDict()
    schema.name = cls.__name__
    schema.doc = ""
    if docs:
        # Summary is the first line, ignoring one leading newline.
        body = docs[1:] if docs.startswith('\n') else docs
        schema.doc = body.split("\n")[0].strip()
        # Drop a surrounding `**...**` emphasis marker from the summary.
        if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
            schema.doc = schema.doc[2:-2].strip()
    schema.category = getattr(cls, '__category__', None) or 'module'
    # Without **kwargs the constructor cannot accept unknown keys.
    schema.strict = not has_kwargs
    schema.pymodule = importlib.import_module(cls.__module__)
    schema.inject = getattr(cls, '__inject__', [])
    schema.shared = getattr(cls, '__shared__', [])

    for idx, name in enumerate(names):
        comment = comments.get(name) or name
        if name in schema.inject:
            type_ = None
        else:
            type_ = annotations.get(name) or None
        value_schema = SchemaValue(name, comment, type_)
        if name in schema.shared:
            assert idx >= num_required, "shared config must have default value"
            default = defaults[idx - num_required]
            value_schema.set_default(SharedConfig(name, default))
        elif idx >= num_required:
            default = defaults[idx - num_required]
            value_schema.set_default(default)
        schema.set_schema(name, value_schema)

    return schema
| |
|