import inspect
import textwrap

from keras.src import backend
from keras.src import dtype_policies
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
from keras.src.ops.node import Node
from keras.src.utils import python_utils
from keras.src.utils import traceback_utils
from keras.src.utils.naming import auto_name


@keras_export("keras.Operation")
class Operation:
    """Base class for a callable operation.

    Calling an operation on symbolic tensors performs output-spec inference
    and records a `Node` in the operations graph (see `symbolic_call()`).
    Otherwise the call is dispatched to `call()`, or to `quantized_call()`
    when a `quantization_mode` is set, optionally routed through
    `rematerialized_call()` when a `_remat_mode` is set.
    """

    def __init__(self, dtype=None, name=None):
        """Initializes the operation.

        Args:
            dtype: Anything accepted by `dtype_policies.get()`; stored as
                the operation's dtype policy.
            name: Optional string name. Must not contain `/`. When `None`,
                a name is auto-generated from the class name.

        Raises:
            ValueError: If `name` is not a string or contains `/`.
        """
        if name is None:
            name = auto_name(self.__class__.__name__)
        if not isinstance(name, str) or "/" in name:
            raise ValueError(
                "Argument `name` must be a string and "
                "cannot contain character `/`. "
                f"Received: name={name} (of type {type(name)})"
            )
        self._dtype_policy = dtype_policies.get(dtype)
        self.name = name
        self._inbound_nodes = []
        self._outbound_nodes = []

    @traceback_utils.filter_traceback
    def __call__(self, *args, **kwargs):
        """Dispatches the call to the appropriate implementation.

        - Symbolic inputs go to `symbolic_call()`.
        - Otherwise `quantized_call()` is used when `quantization_mode` is
          set, else `call()`; when `_remat_mode` is set, that function is
          first wrapped by `rematerialized_call()`.

        When traceback filtering is enabled, the selected callable is
        wrapped so that exceptions carry argument information.
        """
        if any_symbolic_tensors(args, kwargs):
            call_fn = self.symbolic_call
        else:
            # Probe `_remat_mode` before `quantization_mode`, matching the
            # historical evaluation order.
            use_remat = getattr(self, "_remat_mode", None) is not None
            if getattr(self, "quantization_mode", None) is not None:
                base_fn = self.quantized_call
            else:
                base_fn = self.call
            if use_remat:
                # `rematerialized_call` is not defined on this class; it is
                # presumably provided by subclasses that set `_remat_mode`.
                call_fn = self.rematerialized_call(base_fn, *args, **kwargs)
            else:
                call_fn = base_fn

        if traceback_utils.is_traceback_filtering_enabled():
            # Wrap the call to provide helpful info in case of exception.
            call_fn = traceback_utils.inject_argument_info_in_traceback(
                call_fn,
                object_name=(f"{self.__class__.__name__}.call()"),
            )
        return call_fn(*args, **kwargs)

    def symbolic_call(self, *args, **kwargs):
        """Infers output specs and records this call in the graph.

        Returns:
            The result of `compute_output_spec(*args, **kwargs)`.
        """
        # Perform shape/dtype inference.
        outputs = self.compute_output_spec(*args, **kwargs)
        # Record a new node in the operations graph.
        # The Node wires itself to inbound and outbound ops.  The
        # Node constructor updates this op's self._inbound_nodes,
        # sets _keras_history on the outputs, and adds itself to the
        # `_outbound_nodes` of the ops that produced the inputs to this
        # call.
        Node(
            operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs
        )
        return outputs

    def call(self, *args, **kwargs):
        """Eager computation. Must be implemented by subclasses."""
        raise NotImplementedError

    def quantized_call(self, *args, **kwargs):
        """Quantized computation. Must be implemented by subclasses."""
        raise NotImplementedError

    def compute_output_spec(self, *args, **kwargs):
        """Infers the output spec by tracing `self.call`.

        Delegates to `backend.compute_output_spec(self.call, ...)`.

        Raises:
            An exception of the same class as the one raised during
            inference, with a message explaining how to fix it, and the
            original traceback attached.
        """
        try:
            return backend.compute_output_spec(self.call, *args, **kwargs)
        except Exception as e:
            new_e = e.__class__(
                "Could not automatically infer the output shape / dtype of "
                f"'{self.name}' (of type {self.__class__.__name__}). "
                f"Either the `{self.__class__.__name__}.call()` method "
                f"is incorrect, or you need to implement the "
                f"`{self.__class__.__name__}.compute_output_spec() / "
                "compute_output_shape()` method. "
                f"Error encountered:\n\n{e}"
            )
            raise new_e.with_traceback(e.__traceback__) from None

    def __new__(cls, *args, **kwargs):
        """We override __new__ to save serializable constructor arguments.

        These arguments are used to auto-generate an object serialization
        config, which enables user-created subclasses to be serializable
        out of the box in most cases without forcing the user
        to manually implement `get_config()`.
        """
        instance = super(Operation, cls).__new__(cls)

        # Generate a config to be returned by default by `get_config()`.
        # Positional args are mapped onto `__init__`'s parameter names
        # (skipping `self`). `kwargs` is a fresh dict local to this call,
        # so updating it does not affect what `__init__` receives.
        arg_names = inspect.getfullargspec(cls.__init__).args
        kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))

        # Explicitly serialize `dtype` to support auto_config
        dtype = kwargs.get("dtype", None)
        if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
            # For backward compatibility, we use a str (`name`) for
            # `DTypePolicy`
            if dtype.quantization_mode is None:
                kwargs["dtype"] = dtype.name
            # Otherwise, use `dtype_policies.serialize`
            else:
                kwargs["dtype"] = dtype_policies.serialize(dtype)

        # For safety, we only rely on auto-configs for a small set of
        # serializable types.
        supported_types = (str, int, float, bool, type(None))
        try:
            flat_arg_values = tree.flatten(kwargs)
            auto_config = True
            for value in flat_arg_values:
                if not isinstance(value, supported_types):
                    auto_config = False
                    break
        except TypeError:
            auto_config = False
        try:
            instance._lock = False
            if auto_config:
                from keras.src.saving import serialization_lib

                instance._auto_config = serialization_lib.SerializableDict(
                    **kwargs
                )
            else:
                instance._auto_config = None
            instance._lock = True
        except RecursionError:
            # Setting an instance attribute in __new__ has the potential
            # to trigger an infinite recursion if a subclass overrides
            # setattr in an unsafe way.
            pass
        return instance

    @python_utils.default
    def get_config(self):
        """Returns the config of the object.

        An object config is a Python dictionary (serializable)
        containing the information needed to re-instantiate it.

        If a subclass does not override this method, the config is
        auto-generated from the constructor arguments captured in
        `__new__`, provided they were all of simple serializable types.

        Raises:
            NotImplementedError: If the subclass does not override this
                method and no auto-config could be captured.
        """
        config = {
            "name": self.name,
        }

        if not python_utils.is_default(self.get_config):
            # In this case the subclass implements get_config()
            return config

        # In this case the subclass doesn't implement get_config():
        # Let's see if we can autogenerate it.
        if getattr(self, "_auto_config", None) is not None:
            xtra_args = set(config.keys())
            config.update(self._auto_config.config)
            # Remove base-config keys (e.g. `name`) that `__init__` cannot
            # accept: it neither names them nor takes `**kwargs`.
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw != "kwargs":
                for key in xtra_args.difference(argspec.args[1:]):
                    config.pop(key, None)
            return config
        else:
            raise NotImplementedError(
                textwrap.dedent(
                    f"""
        Object {self.__class__.__name__} was created by passing
        non-serializable argument values in `__init__()`,
        and therefore the object must override `get_config()` in
        order to be serializable. Please implement `get_config()`.

        Example:

        class CustomLayer(keras.layers.Layer):
            def __init__(self, arg1, arg2, **kwargs):
                super().__init__(**kwargs)
                self.arg1 = arg1
                self.arg2 = arg2

            def get_config(self):
                config = super().get_config()
                config.update({{
                    "arg1": self.arg1,
                    "arg2": self.arg2,
                }})
                return config"""
                )
            )

    @classmethod
    def from_config(cls, config):
        """Creates an operation from its config.

        This method is the reverse of `get_config`, capable of instantiating
        the same operation from the config dictionary.

        Note: If you override this method, you might receive a serialized
        dtype config, which is a `dict`. You can deserialize it as follows:

        ```python
        if "dtype" in config and isinstance(config["dtype"], dict):
            policy = dtype_policies.deserialize(config["dtype"])
        ```

        Args:
            config: A Python dictionary, typically the output of `get_config`.

        Returns:
            An operation instance.

        Raises:
            TypeError: If `cls(**config)` raises; the original exception is
                chained as the cause.
        """
        # Explicitly deserialize dtype config if needed. This enables users to
        # directly interact with the instance of `DTypePolicy`.
        if "dtype" in config and isinstance(config["dtype"], dict):
            # Copy so the caller's dict is not mutated.
            config = config.copy()
            policy = dtype_policies.deserialize(config["dtype"])
            if (
                not isinstance(policy, dtype_policies.DTypePolicyMap)
                and policy.quantization_mode is None
            ):
                # For backward compatibility, we use a str (`name`) for
                # `DTypePolicy`
                policy = policy.name
            config["dtype"] = policy
        try:
            return cls(**config)
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise TypeError(
                f"Error when deserializing class '{cls.__name__}' using "
                f"config={config}.\n\nException encountered: {e}"
            ) from e

    def __repr__(self):
        return f"<Operation name={self.name}>"

    @property
    def input(self):
        """Retrieves the input tensor(s) of a symbolic operation.

        Only returns the tensor(s) corresponding to the *first time*
        the operation was called.

        Returns:
            Input tensor or list of input tensors.
        """
        return self._get_node_attribute_at_index(0, "input_tensors", "input")

    @property
    def output(self):
        """Retrieves the output tensor(s) of a symbolic operation.

        Only returns the tensor(s) corresponding to the *first time*
        the operation was called.

        Returns:
            Output tensor or list of output tensors.
        """
        return self._get_node_attribute_at_index(0, "output_tensors", "output")

    def _get_node_attribute_at_index(self, node_index, attr, attr_name):
        """Private utility to retrieve an attribute (e.g. inputs) from a node.

        This is used to implement the properties:
        - output
        - input

        Args:
            node_index: Integer index of the node from which
                to retrieve the attribute.
            attr: Exact node attribute name.
            attr_name: Human-readable attribute name, for error messages.

        Returns:
            The operation's attribute `attr` at the node of index `node_index`.
            A single-element list is unwrapped to its sole element.

        Raises:
            AttributeError: If the operation has never been called.
            ValueError: If `node_index` is out of range.
        """
        if not self._inbound_nodes:
            raise AttributeError(
                f"The layer {self.name} has never been called "
                f"and thus has no defined {attr_name}."
            )
        if not len(self._inbound_nodes) > node_index:
            raise ValueError(
                f"Asked to get {attr_name} at node "
                f"{node_index}, but the operation has only "
                f"{len(self._inbound_nodes)} inbound nodes."
            )
        values = getattr(self._inbound_nodes[node_index], attr)
        if isinstance(values, list) and len(values) == 1:
            return values[0]
        else:
            return values

    # Hooks for backend layer classes
    def _post_build(self):
        """Can be overridden for per backend post build actions."""
        pass

    def _setattr_hook(self, name, value):
        """Can be overridden for per backend attribute-setting actions.

        Returns:
            The `(name, value)` pair, unchanged by default.
        """
        return name, value

    def _post_track_variable(self, variable):
        """Can be overridden for per backend post track actions."""
        pass

    def _post_untrack_variable(self, variable):
        """Can be overridden for per backend post untrack actions."""
        pass