| |
|
|
| from collections import OrderedDict |
| from textwrap import dedent |
| import operator |
|
|
| from . import ExprNodes |
| from . import Nodes |
| from . import PyrexTypes |
| from . import Builtin |
| from . import Naming |
| from .Errors import error, warning |
| from .Code import UtilityCode, TempitaUtilityCode, PyxCodeWriter |
| from .Visitor import VisitorTransform |
| from .StringEncoding import EncodedString |
| from .TreeFragment import TreeFragment |
| from .ParseTreeTransforms import NormalizeTree, SkipDeclarations |
| from .Options import copy_inherited_directives |
|
|
_dataclass_loader_utilitycode = None


def make_dataclasses_module_callnode(pos):
    """
    Return a node that evaluates to the ``dataclasses`` module at runtime.

    The loader utility code is built lazily on the first call and cached in the
    module-level ``_dataclass_loader_utilitycode``. It instantiates the
    "SpecificModuleLoader" template from Dataclasses.c, embedding the
    "Dataclasses_fallback" Python code (from Dataclasses.py) as a C string literal.
    """
    global _dataclass_loader_utilitycode
    if not _dataclass_loader_utilitycode:
        fallback = UtilityCode.load_cached("Dataclasses_fallback", "Dataclasses.py")
        fallback_source = EncodedString(fallback.impl)
        template_context = {
            'cname': "dataclasses",
            'py_code': fallback_source.as_c_string_literal(),
        }
        _dataclass_loader_utilitycode = TempitaUtilityCode.load(
            "SpecificModuleLoader", "Dataclasses.c", context=template_context)

    loader_signature = PyrexTypes.CFuncType(PyrexTypes.py_object_type, [])
    return ExprNodes.PythonCapiCallNode(
        pos, "__Pyx_Load_dataclasses_Module", loader_signature,
        utility_code=_dataclass_loader_utilitycode,
        args=[],
    )
|
|
def make_dataclass_call_helper(pos, callable, kwds):
    """
    Return a node calling the C helper ``__Pyx_DataclassesCallHelper(callable, kwds)``.

    Both arguments and the result are typed as Python objects. The helper's
    implementation comes from the "DataclassesCallHelper" utility code in
    Dataclasses.c.
    """
    helper_params = [
        PyrexTypes.CFuncTypeArg("callable", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("kwds", PyrexTypes.py_object_type, None),
    ]
    helper_signature = PyrexTypes.CFuncType(PyrexTypes.py_object_type, helper_params)
    helper_code = UtilityCode.load_cached("DataclassesCallHelper", "Dataclasses.c")
    return ExprNodes.PythonCapiCallNode(
        pos,
        function_name="__Pyx_DataclassesCallHelper",
        func_type=helper_signature,
        utility_code=helper_code,
        args=[callable, kwds],
    )
|
|
|
|
class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
    """
    Cython (and Python) normally treats

        class A:
            x = 1

    as generating a class attribute. However for dataclasses the `= 1` should be interpreted as
    a default value to initialize an instance attribute with.
    This transform therefore removes the `x=1` assignment so that the class attribute isn't
    generated, while recording (in ``removed_assignments``, name -> right-hand side node)
    what it has removed so that it can be used in the initialization.
    """
    def __init__(self, names):
        super().__init__()
        self.names = names
        self.removed_assignments = {}

    def visit_CClassNode(self, node):
        self.visitchildren(node)
        return node

    def visit_PyClassNode(self, node):
        # A nested Python class has its own namespace: don't descend into it.
        return node

    def visit_FuncDefNode(self, node):
        # Assignments inside functions are not class-level: don't descend.
        return node

    def visit_SingleAssignmentNode(self, node):
        target = node.lhs
        if not (target.is_name and target.name in self.names):
            return node

        field_name = target.name
        if field_name in self.removed_assignments:
            warning(node.pos, ("Multiple assignments for '%s' in dataclass; "
                               "using most recent") % field_name, 1)
        self.removed_assignments[field_name] = node.rhs
        # Returning an empty list drops the assignment statement from the tree.
        return []

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
|
|
|
|
class TemplateCode:
    """
    Wraps a PyxCodeWriter and additionally tracks placeholder names.

    ``placeholders`` maps each generated placeholder name to the node it stands
    for; it is substituted when the code is turned into a tree.
    ``extra_stats`` holds nodes that are appended after the generated code in
    that tree. Writers created via ``insertion_point()`` share both containers
    with the writer they came from.
    """
    _placeholder_count = 0

    def __init__(self, writer=None, placeholders=None, extra_stats=None):
        if writer is None:
            writer = PyxCodeWriter()
        if placeholders is None:
            placeholders = {}
        if extra_stats is None:
            extra_stats = []
        self.writer = writer
        self.placeholders = placeholders
        self.extra_stats = extra_stats

    def add_code_line(self, code_line):
        self.writer.putln(code_line)

    def add_code_chunk(self, code_chunk):
        self.writer.put_chunk(code_chunk)

    def reset(self):
        # Placeholders are deliberately kept: unused ones do no harm.
        self.writer.reset()

    def empty(self):
        return self.writer.empty()

    def indent(self):
        self.writer.indent()

    def dedent(self):
        self.writer.dedent()

    def indenter(self, block_opener_line):
        return self.writer.indenter(block_opener_line)

    def new_placeholder(self, field_names, value):
        """Register ``value`` under a fresh placeholder name and return that name."""
        placeholder = self._new_placeholder_name(field_names)
        self.placeholders[placeholder] = value
        return placeholder

    def add_extra_statements(self, statements):
        assert self.extra_stats is not None, "Can only use add_extra_statements on top-level writer"
        self.extra_stats.extend(statements)

    def _new_placeholder_name(self, field_names):
        # The name must be unused and must not clash with a field name
        # (unlikely, but possible).
        candidate = f"DATACLASS_PLACEHOLDER_{self._placeholder_count:d}"
        while candidate in self.placeholders or candidate in field_names:
            self._placeholder_count += 1
            candidate = f"DATACLASS_PLACEHOLDER_{self._placeholder_count:d}"
        return candidate

    def generate_tree(self, level='c_class'):
        fragment = TreeFragment(
            self.writer.getvalue(),
            level=level,
            pipeline=[NormalizeTree(None)],
        )
        stat_list_node = fragment.substitute(self.placeholders)
        stat_list_node.stats += self.extra_stats
        return stat_list_node

    def insertion_point(self):
        return TemplateCode(
            writer=self.writer.insertion_point(),
            placeholders=self.placeholders,
            extra_stats=self.extra_stats,
        )
|
|
|
|
class _MISSING_TYPE:
    """Type of the MISSING sentinel, marking a Field option that was not supplied."""


MISSING = _MISSING_TYPE()
|
|
|
|
class Field:
    """
    Compile-time counterpart of the standard library's ``dataclasses.Field``.

    Used internally during the generation of Cython dataclasses to keep track
    of the settings for individual attributes. Options are stored as
    expression nodes (e.g. a BoolNode rather than a bool) so they can be
    used directly in code construction. ``default`` and ``default_factory``
    stay as the MISSING sentinel unless supplied.
    """
    default = MISSING
    default_factory = MISSING
    private = False

    # Options that must be given as compile-time literal values.
    literal_keys = ("repr", "hash", "init", "compare", "metadata")

    # Option defaults mirror those of CPython's dataclasses.field().
    def __init__(self, pos, default=MISSING, default_factory=MISSING,
                 repr=None, hash=None, init=None,
                 compare=None, metadata=None,
                 is_initvar=False, is_classvar=False,
                 **additional_kwds):
        # Only shadow the class-level MISSING defaults when a value was given.
        if default is not MISSING:
            self.default = default
        if default_factory is not MISSING:
            self.default_factory = default_factory

        self.repr = repr if repr else ExprNodes.BoolNode(pos, value=True)
        self.hash = hash if hash else ExprNodes.NoneNode(pos)
        self.init = init if init else ExprNodes.BoolNode(pos, value=True)
        self.compare = compare if compare else ExprNodes.BoolNode(pos, value=True)
        self.metadata = metadata if metadata else ExprNodes.NoneNode(pos)
        self.is_initvar = is_initvar
        self.is_classvar = is_classvar

        # Any other keyword is not a supported field() option.
        for unexpected_name, value_node in additional_kwds.items():
            error(value_node.pos,
                  "cython.dataclasses.field() got an unexpected keyword argument '%s'" % unexpected_name)

        for option_name in self.literal_keys:
            option_node = getattr(self, option_name)
            if not option_node.is_literal:
                error(option_node.pos,
                      "cython.dataclasses.field parameter '%s' must be a literal value" % option_name)

    def iterate_record_node_arguments(self):
        """Yield ``(option_name, node)`` for every option that is not MISSING."""
        record_keys = self.literal_keys + ('default', 'default_factory')
        for key in record_keys:
            node = getattr(self, key)
            if node is not MISSING:
                yield key, node
|
|
|
|
def process_class_get_fields(node):
    """
    Collect the dataclass fields of the cdef class ``node``.

    Fields come from the class's ``var_entries``, sorted by declaration position.
    They are added after a copy of the fields of the nearest base type that has
    ``dataclass_fields``. Redefining an inherited name replaces its Field but
    keeps its original position (OrderedDict semantics).

    Class-level assignments to those names (``x = 1`` or
    ``x = cython.dataclasses.field(...)``) are removed from the class body by
    RemoveAssignmentsToNames. They become the field's default or options.

    The resulting OrderedDict (name -> Field) is stored on
    ``node.entry.type.dataclass_fields`` and returned.
    """
    var_entries = node.scope.var_entries
    # order of definition is used in the dataclass
    var_entries = sorted(var_entries, key=operator.attrgetter('pos'))
    var_names = [entry.name for entry in var_entries]

    # don't treat `x = 1` as an assignment of a class attribute within the dataclass;
    # the transform strips it from the class body and records the right-hand side
    transform = RemoveAssignmentsToNames(var_names)
    transform(node)
    default_value_assignments = transform.removed_assignments

    # Walk up the base types; start from a copy of the first dataclass_fields found.
    base_type = node.base_type
    fields = OrderedDict()
    while base_type:
        if base_type.is_external or not base_type.scope.implemented:
            warning(node.pos, "Cannot reliably handle Cython dataclasses with base types "
                "in external modules since it is not possible to tell what fields they have", 2)
        if base_type.dataclass_fields:
            fields = base_type.dataclass_fields.copy()
            break
        base_type = base_type.base_type

    for entry in var_entries:
        name = entry.name
        is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar")
        # NOTE(review): ClassVar attributes may not appear in var_entries at all,
        # in which case this is never True - confirm.
        is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar")
        if name in default_value_assignments:
            assignment = default_value_assignments[name]
            if (isinstance(assignment, ExprNodes.CallNode) and (
                    assignment.function.as_cython_attribute() == "dataclasses.field" or
                    Builtin.exprnode_to_known_standard_library_name(
                        assignment.function, node.scope) == "dataclasses.field")):
                # `x = (cython.)dataclasses.field(...)`: only keyword arguments
                # (or no arguments at all) are accepted.
                valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode)
                        and isinstance(assignment.positional_args, ExprNodes.TupleNode)
                        and not assignment.positional_args.args
                        and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode)))
                valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args)
                if not (valid_general_call or valid_simple_call):
                    error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist "
                          "of compile-time keyword arguments")
                    # the field is left out of `fields` after the error
                    continue
                keyword_args = assignment.keyword_args.as_python_dict() if valid_general_call and assignment.keyword_args else {}
                if 'default' in keyword_args and 'default_factory' in keyword_args:
                    error(assignment.pos, "cannot specify both default and default_factory")
                    continue
                field = Field(node.pos, **keyword_args)
            else:
                if assignment.type in [Builtin.list_type, Builtin.dict_type, Builtin.set_type]:
                    # The standard library module rejects these at runtime
                    # (with a ValueError); the message mirrors CPython's.
                    error(assignment.pos, "mutable default <class '{}'> for field {} is not allowed: "
                          "use default_factory".format(assignment.type.name, name))

                # a plain `x = value` becomes the field's default
                field = Field(node.pos, default=assignment)
        else:
            field = Field(node.pos)
        field.is_initvar = is_initvar
        field.is_classvar = is_classvar
        if entry.visibility == "private":
            field.private = True
        fields[name] = field
    node.entry.type.dataclass_fields = fields
    return fields
|
|
|
|
def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
    """
    Turn the cdef class ``node`` into a dataclass.

    ``dataclass_args`` is either None or a ``(positional_args, keyword_args)``
    pair taken from the ``cython.dataclasses.dataclass(...)`` decorator. Only
    keyword arguments with BoolNode values are accepted. The generated methods
    and the ``__dataclass_params__`` / ``__dataclass_fields__`` attributes are
    analysed with ``analyse_decs_transform`` and appended to ``node.body``.
    """
    # default argument values from https://docs.python.org/3/library/dataclasses.html
    kwargs = dict(init=True, repr=True, eq=True,
                  order=False, unsafe_hash=False,
                  frozen=False, kw_only=False, match_args=True)
    if dataclass_args is not None:
        if dataclass_args[0]:
            error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
        for k, v in dataclass_args[1].items():
            if k not in kwargs:
                error(node.pos,
                      "cython.dataclasses.dataclass() got an unexpected keyword argument '%s'" % k)
                # don't forward an unknown option to _DataclassParams
                continue
            if not isinstance(v, ExprNodes.BoolNode):
                error(node.pos,
                      "Arguments passed to cython.dataclasses.dataclass must be True or False")
                # only a BoolNode is known to carry a usable bool ``.value``
                continue
            kwargs[k] = v.value

    kw_only = kwargs['kw_only']

    fields = process_class_get_fields(node)

    dataclass_module = make_dataclasses_module_callnode(node.pos)

    # Create the __dataclass_params__ attribute with the `_DataclassParams` class
    # from the dataclasses module, for maximum duck-typing compatibility.
    # ``kwargs`` already holds every supported option (kw_only and match_args
    # included), so only the unsupported ones are appended, always as False.
    dataclass_params_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                                    attribute=EncodedString("_DataclassParams"))
    params_items = list(kwargs.items()) + [('slots', False), ('weakref_slot', False)]
    dataclass_params_keywords = ExprNodes.DictNode.from_pairs(
        node.pos,
        [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
          ExprNodes.BoolNode(node.pos, value=v))
         for k, v in params_items])
    dataclass_params = make_dataclass_call_helper(
        node.pos, dataclass_params_func, dataclass_params_keywords)
    dataclass_params_assignment = Nodes.SingleAssignmentNode(
        node.pos,
        lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_params__")),
        rhs=dataclass_params)

    dataclass_fields_stats = _set_up_dataclass_fields(node, fields, dataclass_module)

    stats = Nodes.StatListNode(node.pos,
                               stats=[dataclass_params_assignment] + dataclass_fields_stats)

    # Generate the methods as Cython source text, then parse it into a tree.
    code = TemplateCode()
    generate_init_code(code, kwargs['init'], node, fields, kw_only)
    generate_match_args(code, kwargs['match_args'], node, fields, kw_only)
    generate_repr_code(code, kwargs['repr'], node, fields)
    generate_eq_code(code, kwargs['eq'], node, fields)
    generate_order_code(code, kwargs['order'], node, fields)
    generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields)

    stats.stats += code.generate_tree().stats

    # Annotation typing is turned off for the generated code. Presumably this is
    # so that annotated __init__ arguments can receive the _HAS_DEFAULT_FACTORY
    # sentinel used by generate_init_code.
    comp_directives = Nodes.CompilerDirectivesNode(node.pos,
        directives=copy_inherited_directives(node.scope.directives, annotation_typing=False),
        body=stats)

    comp_directives.analyse_declarations(node.scope)
    # probably already in this scope, but it doesn't hurt to make sure
    analyse_decs_transform.enter_scope(node, node.scope)
    analyse_decs_transform.visit(comp_directives)
    analyse_decs_transform.exit_scope()

    node.body.stats.append(comp_directives)
|
|
|
|
def generate_init_code(code, init, node, fields, kw_only):
    """
    Notes on CPython generated "__init__":
    * Implemented in `_init_fn`.
    * The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
      the default argument for fields that need constructing with a factory
      function is copied from the CPython implementation. (`None` isn't
      suitable because it could also be a value for the user to pass.)
      There's no real reason why it needs importing from the dataclasses module
      though - it could equally be a value generated by Cython when the module loads.
    * seen_default and the associated error message are copied directly from Python
    * Call to user-defined __post_init__ function (if it exists) is copied from
      CPython.

    Cython behaviour deviates a little here (to be decided if this is right...)
    Because the class variable from the assignment does not exist Cython fields will
    return None (or whatever their type default is) if not initialized while Python
    dataclasses will fall back to looking up the class variable.

    Nothing is generated if ``init`` is false or the class itself already defines
    ``__init__``. On a "non-default argument follows default argument" error,
    ``reset()`` is called on the body writer and the ``def`` line is never written.
    """
    if not init or node.scope.lookup_here("__init__"):
        return

    # selfname behaviour copied from the cpython module: avoid clashing with a field named "self"
    selfname = "__dataclass_self__" if "self" in fields else "self"
    args = [selfname]

    if kw_only:
        # NOTE(review): if no field has init=True this yields "def __init__(self, *)",
        # which plain Python rejects; presumably Cython's parser accepts it - confirm.
        args.append("*")

    # The "def __init__(...)" line can only be written once all arguments are known,
    # so it goes into an earlier insertion point; the body is written, indented,
    # into a second insertion point.
    function_start_point = code.insertion_point()
    code = code.insertion_point()
    code.indent()

    # create a temp to get _HAS_DEFAULT_FACTORY
    dataclass_module = make_dataclasses_module_callnode(node.pos)
    has_default_factory = ExprNodes.AttributeNode(
        node.pos,
        obj=dataclass_module,
        attribute=EncodedString("_HAS_DEFAULT_FACTORY")
    )

    default_factory_placeholder = code.new_placeholder(fields, has_default_factory)

    seen_default = False
    for name, field in fields.items():
        entry = node.scope.lookup(name)
        if entry.annotation:
            annotation = f": {entry.annotation.string.value}"
        else:
            annotation = ""
        assignment = ''
        if field.default is not MISSING or field.default_factory is not MISSING:
            if field.init.value:
                seen_default = True
            if field.default_factory is not MISSING:
                # factory fields default to the sentinel; the factory is called in the body
                ph_name = default_factory_placeholder
            else:
                ph_name = code.new_placeholder(fields, field.default)  # 'default' should be a node
            assignment = f" = {ph_name}"
        elif seen_default and not kw_only and field.init.value:
            error(entry.pos, ("non-default argument '%s' follows default argument "
                              "in dataclass __init__") % name)
            code.reset()
            return

        if field.init.value:
            args.append(f"{name}{annotation}{assignment}")

        if field.is_initvar:
            # InitVars are only passed on to __post_init__, never stored on self
            continue
        elif field.default_factory is MISSING:
            if field.init.value:
                code.add_code_line(f"{selfname}.{name} = {name}")
            elif assignment:
                # not an argument to the function, but is still initialized
                code.add_code_line(f"{selfname}.{name}{assignment}")
        else:
            ph_name = code.new_placeholder(fields, field.default_factory)
            if field.init.value:
                # close to:
                # def __init__(self, name=_PLACEHOLDER_VALUE):
                #     self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
                code.add_code_line(
                    f"{selfname}.{name} = {ph_name}() if {name} is {default_factory_placeholder} else {name}"
                )
            else:
                # still need to use the default factory to initialize
                code.add_code_line(f"{selfname}.{name} = {ph_name}()")

    if node.scope.lookup("__post_init__"):
        post_init_vars = ", ".join(name for name, field in fields.items()
                                   if field.is_initvar)
        code.add_code_line(f"{selfname}.__post_init__({post_init_vars})")

    if code.empty():
        code.add_code_line("pass")

    args = ", ".join(args)
    function_start_point.add_code_line(f"def __init__({args}):")
|
|
|
|
def generate_match_args(code, match_args, node, fields, global_kw_only):
    """
    Generates a tuple containing what would be the positional args to __init__.

    Note that this is generated even if the user overrides init.
    As in CPython, only fields that are parameters of the generated __init__
    are included. Fields declared with ``init=False``, ClassVars and
    keyword-only fields are skipped. Nothing is generated when
    ``match_args`` is false or the class defines ``__match_args__`` itself.
    """
    if not match_args or node.scope.lookup_here("__match_args__"):
        return
    positional_arg_names = []
    for field_name, field in fields.items():
        if field.is_classvar or not field.init.value:
            # not a parameter of __init__, so not matchable positionally
            continue
        # TODO hasattr and global_kw_only can be removed once full kw_only support is added
        field_is_kw_only = global_kw_only or (
            hasattr(field, 'kw_only') and field.kw_only.value
        )
        if not field_is_kw_only:
            positional_arg_names.append(field_name)
    code.add_code_line("__match_args__ = %s" % str(tuple(positional_arg_names)))
|
|
|
|
def generate_repr_code(code, repr, node, fields):
    """
    The core of the CPython implementation is just:
        ['return self.__class__.__qualname__ + f"(' +
                ', '.join([f"{f.name}={{self.{f.name}!r}}"
                           for f in fields]) +
                ')"'],

    Here the class name is obtained as
    ``getattr(type(self), "__qualname__", None) or type(self).__name__``,
    i.e. ``__qualname__`` with a fallback to ``__name__``.
    Fields with ``repr=False`` and InitVars are omitted.

    However, it also has some guards for recursive repr invocations. In the standard
    library implementation they're done with a wrapper decorator that captures a set
    (with the set keyed by id and thread). Here we keep the set on a
    ``threading.local`` class attribute and key only by id. Attributes of a
    ``threading.local`` are only visible in the thread that set them, so
    ``__repr__`` creates the current thread's set on first use.
    """
    if not repr or node.scope.lookup("__repr__"):
        return

    # The recursive guard is likely a little costly, so skip it if possible.
    # is_gc_simple defines where it can contain recursive objects
    needs_recursive_guard = False
    for name in fields.keys():
        entry = node.scope.lookup(name)
        type_ = entry.type
        if type_.is_memoryviewslice:
            type_ = type_.dtype
        if not type_.is_pyobject:
            continue  # no GC
        if not type_.is_gc_simple:
            needs_recursive_guard = True
            break

    if needs_recursive_guard:
        code.add_code_chunk("""
            __pyx_recursive_repr_guard = __import__('threading').local()
            __pyx_recursive_repr_guard.running = set()
        """)

    with code.indenter("def __repr__(self):"):
        if needs_recursive_guard:
            # `.running` set at class creation only exists in that thread,
            # so (re)create it lazily for the calling thread.
            code.add_code_chunk("""
                key = id(self)
                guard = self.__pyx_recursive_repr_guard
                if not hasattr(guard, 'running'): guard.running = set()
                guard_set = guard.running
                if key in guard_set: return '...'
                guard_set.add(key)
                try:
            """)
            code.indent()

        strs = ["%s={self.%s!r}" % (name, name)
                for name, field in fields.items()
                if field.repr.value and not field.is_initvar]
        format_string = ", ".join(strs)

        code.add_code_chunk(f'''
            name = getattr(type(self), "__qualname__", None) or type(self).__name__
            return f'{{name}}({format_string})'
        ''')
        if needs_recursive_guard:
            code.dedent()
            with code.indenter("finally:"):
                code.add_code_line("guard_set.remove(key)")
|
|
|
|
def generate_cmp_code(code, op, funcname, node, fields):
    """
    Emit the rich-comparison method ``funcname`` implementing operator ``op``.

    ``op`` is one of "==", "<", "<=", ">", ">=". Fields with ``compare=False``
    and InitVars are ignored. Nothing is generated if the class already defines
    ``funcname`` itself. Comparing against an object of a different class
    returns NotImplemented.
    """
    if node.scope.lookup_here(funcname):
        return

    names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]

    with code.indenter(f"def {funcname}(self, other):"):
        code.add_code_chunk(f"""
            if other.__class__ is not self.__class__: return NotImplemented

            cdef {node.class_name} other_cast
            other_cast = <{node.class_name}>other
        """)

        # The Python implementation of dataclasses.py does a tuple comparison
        # (roughly):
        #  return self._attributes_to_tuple() {op} other._attributes_to_tuple()
        #
        # For the Cython implementation a tuple comparison isn't an option because
        # not all attributes can be converted to Python objects and stored in a tuple,
        # so the tuple comparison is unrolled field by field instead.
        op_without_equals = op.replace('=', '')

        for name in names:
            if op != '==':
                # tuple comparison rules - early elements take precedence
                code.add_code_line(f"if self.{name} {op_without_equals} other_cast.{name}: return True")
            code.add_code_line(f"if self.{name} != other_cast.{name}: return False")
        # all compared fields equal: true only for ==, <= and >= ("() == ()" is True)
        code.add_code_line(f"return {'True' if '=' in op else 'False'}")
|
|
|
|
def generate_eq_code(code, eq, node, fields):
    """Emit ``__eq__`` through generate_cmp_code, but only when ``eq`` is true."""
    if eq:
        generate_cmp_code(code, "==", "__eq__", node, fields)
|
|
|
|
def generate_order_code(code, order, node, fields):
    """Emit ``__lt__``, ``__le__``, ``__gt__`` and ``__ge__`` when ``order`` is true."""
    if order:
        ordering_methods = (
            ("<", "__lt__"),
            ("<=", "__le__"),
            (">", "__gt__"),
            (">=", "__ge__"),
        )
        for operator_symbol, method_name in ordering_methods:
            generate_cmp_code(code, operator_symbol, method_name, node, fields)
|
|
|
|
def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
    """
    Copied from CPython implementation - the intention is to follow this as far as
    is possible:
    # +------------------- unsafe_hash= parameter
    # |       +----------- eq= parameter
    # |       |       +--- frozen= parameter
    # |       |       |
    # v       v       v    |        |        |
    #                      |   no   |  yes   |  <--- class has explicitly defined __hash__
    # +=======+=======+=======+========+========+
    # | False | False | False |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | False | True  |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | True  | False | None   |        | <-- the default, not hashable
    # +-------+-------+-------+--------+--------+
    # | False | True  | True  | add    |        | Frozen, so hashable, allows override
    # +-------+-------+-------+--------+--------+
    # | True  | False | False | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | False | True  | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | False | add    | raise  | Not frozen, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | True  | add    | raise  | Frozen, so hashable
    # +=======+=======+=======+========+========+
    # For boxes that are blank, __hash__ is untouched and therefore
    # inherited from the base class.  If the base is object, then
    # id-based hashing is used.

    Like the Python implementation, the generated ``__hash__`` builds a tuple
    of the hashed fields and returns its hash: ``hash((self.a, self.b,))``.
    A field is hashed when its ``hash`` option is true, or when ``hash`` is
    None (the default) and its ``compare`` option is true. InitVars are never
    hashed.
    """
    hash_entry = node.scope.lookup_here("__hash__")
    if hash_entry:
        # TODO ideally assignment of __hash__ to None shouldn't trigger this
        # but difficult to get the right information here
        if unsafe_hash:
            # error message taken from CPython dataclasses module
            error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
        return

    if not unsafe_hash:
        if not eq:
            return
        if not frozen:
            # eq without frozen: explicitly unhashable (__hash__ = None)
            code.add_extra_statements([
                Nodes.SingleAssignmentNode(
                    node.pos,
                    lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
                    rhs=ExprNodes.NoneNode(node.pos),
                )
            ])
            return

    def _included_in_hash(field):
        if field.is_initvar:
            return False
        hash_node = field.hash
        # hash=None means "use compare". NOTE: the default is a NoneNode, and
        # Cython's NoneNode appears to keep a C expression ("Py_None") rather
        # than Python None in ``.value``, so also check ``is_none``.
        if getattr(hash_node, 'is_none', False) or hash_node.value is None:
            return field.compare.value
        return hash_node.value

    names = [name for name, field in fields.items() if _included_in_hash(field)]

    hash_tuple_items = ", ".join("self.%s" % name for name in names)
    if hash_tuple_items:
        hash_tuple_items += ","  # ensure that one arg form is a tuple

    # if we're here we want to generate a hash
    with code.indenter("def __hash__(self):"):
        code.add_code_line(f"return hash(({hash_tuple_items}))")
|
|
|
|
def get_field_type(pos, entry):
    """
    Return the node to use as the ``.type`` of a field record.

    When the attribute has an annotation, its ``.string`` is returned, as the
    dataclasses module records annotations. Otherwise (e.g. attributes
    declared with cdef), a string node holding the display form of the C type
    declaration is created as a fallback.
    """
    annotation = entry.annotation
    if not annotation:
        display_name = entry.type.declaration_code("", for_display=1)
        return ExprNodes.UnicodeNode(pos, value=EncodedString(display_name))
    return annotation.string
|
|
|
|
class FieldRecordNode(ExprNodes.ExprNode):
    """
    __dataclass_fields__ contains a bunch of field objects recording how each field
    of the dataclass was initialized (mainly corresponding to the arguments passed to
    the "field" function). This node is used for the attributes of these field objects.

    If possible, coerces `arg` to a Python object.
    Otherwise, generates a sensible backup string.
    """
    subexprs = ['arg']

    def __init__(self, pos, arg):
        super().__init__(pos, arg=arg)

    def analyse_types(self, env):
        # analyse_types() returns the analysed node (as this method does too),
        # which may be a replacement for the original; keep it.
        self.arg = self.arg.analyse_types(env)
        self.type = self.arg.type
        return self

    def coerce_to_pyobject(self, env):
        if self.arg.type.can_coerce_to_pyobject(env):
            return self.arg.coerce_to_pyobject(env)
        else:
            # The value can't become a Python object, so record a string
            # rendering of the expression that produced it instead.
            return self._make_string()

    def _make_string(self):
        from .AutoDocTransforms import AnnotationWriter
        writer = AnnotationWriter(description="Dataclass field")
        string = writer.write(self.arg)
        return ExprNodes.UnicodeNode(self.pos, value=EncodedString(string))

    def generate_evaluation_code(self, code):
        return self.arg.generate_evaluation_code(code)
|
|
|
|
def _set_up_dataclass_fields(node, fields, dataclass_module):
    """
    Build the statements that create ``__dataclass_fields__`` on the class.

    Returns a list of statement nodes:
    * assignments of each non-literal, non-name ``default``/``default_factory``
      expression to a new module-level cdef variable (the Field attribute is
      replaced by a NameNode for that variable, so code generated afterwards
      from the Field refers to the variable instead of the expression),
    * ``__dataclass_fields__ = {name: dataclasses.field(**options), ...}``,
    * per-field statements setting ``.name``, ``.type`` and ``._field_type`` on
      each of those field objects.

    Fields marked ``private`` are skipped throughout.
    """
    variables_assignment_stats = []
    for name, field in fields.items():
        if field.private:
            continue
        for attrname in [ "default", "default_factory" ]:
            field_default = getattr(field, attrname)
            if field_default is MISSING or field_default.is_literal or field_default.is_name:
                # simple cases that don't need a module-level variable
                continue
            global_scope = node.scope.global_scope()
            # NOTE(review): the variable name depends on the class and field
            # name only, not on attrname; this relies on a Field never having
            # both default and default_factory (process_class_get_fields rejects that).
            module_field_name = global_scope.mangle(
                global_scope.mangle(Naming.dataclass_field_default_cname, node.class_name),
                name)
            # create an entry in the global scope for this variable to live
            field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name))
            field_node.entry = global_scope.declare_var(
                field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
                pos=field_default.pos, cname=field_node.name, is_cdef=True,
            )
            # replace the field so that later users just receive the NameNode
            setattr(field, attrname, field_node)

            variables_assignment_stats.append(
                Nodes.SingleAssignmentNode(field_default.pos, lhs=field_node, rhs=field_default))

    placeholders = {}
    field_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                         attribute=EncodedString("field"))
    dc_fields = ExprNodes.DictNode(node.pos, key_value_pairs=[])
    dc_fields_namevalue_assignments = []

    for name, field in fields.items():
        if field.private:
            continue
        type_placeholder_name = "PLACEHOLDER_%s" % name
        placeholders[type_placeholder_name] = get_field_type(
            node.pos, node.scope.entries[name]
        )

        # Record the kind of field (regular, InitVar or ClassVar) using the
        # marker objects from the dataclasses module.
        field_type_placeholder_name = "PLACEHOLDER_FIELD_TYPE_%s" % name
        if field.is_initvar:
            placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
                node.pos, obj=dataclass_module,
                attribute=EncodedString("_FIELD_INITVAR")
            )
        elif field.is_classvar:
            # NOTE(review): possibly never reached if ClassVars are absent from the fields
            placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
                node.pos, obj=dataclass_module,
                attribute=EncodedString("_FIELD_CLASSVAR")
            )
        else:
            placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
                node.pos, obj=dataclass_module,
                attribute=EncodedString("_FIELD")
            )

        # keyword arguments for dataclasses.field(): every option that was given
        dc_field_keywords = ExprNodes.DictNode.from_pairs(
            node.pos,
            [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
              FieldRecordNode(node.pos, arg=v))
             for k, v in field.iterate_record_node_arguments()]
        )
        dc_field_call = make_dataclass_call_helper(
            node.pos, field_func, dc_field_keywords
        )
        dc_fields.key_value_pairs.append(
            ExprNodes.DictItemNode(
                node.pos,
                key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
                value=dc_field_call))
        # These attributes are set after construction, as Cython source text
        # that is parsed below with the placeholders substituted.
        dc_fields_namevalue_assignments.append(
            dedent(f"""\
                __dataclass_fields__[{name!r}].name = {name!r}
                __dataclass_fields__[{name!r}].type = {type_placeholder_name}
                __dataclass_fields__[{name!r}]._field_type = {field_type_placeholder_name}
            """))

    dataclass_fields_assignment = \
        Nodes.SingleAssignmentNode(node.pos,
                                   lhs = ExprNodes.NameNode(node.pos,
                                                            name=EncodedString("__dataclass_fields__")),
                                   rhs = dc_fields)

    dc_fields_namevalue_assignments = "\n".join(dc_fields_namevalue_assignments)
    dc_fields_namevalue_assignments = TreeFragment(dc_fields_namevalue_assignments,
                                                   level="c_class",
                                                   pipeline=[NormalizeTree(None)])
    dc_fields_namevalue_assignments = dc_fields_namevalue_assignments.substitute(placeholders)

    return (variables_assignment_stats
            + [dataclass_fields_assignment]
            + dc_fields_namevalue_assignments.stats)
|
|