|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""This module contains the user- and codegen-facing API for DiastaticMalt.""" |
|
|
|
|
|
import functools |
|
|
import importlib |
|
|
import inspect |
|
|
import os |
|
|
import sys |
|
|
import textwrap |
|
|
import traceback |
|
|
|
|
|
from malt import operators |
|
|
from malt import utils |
|
|
from malt.converters import asserts |
|
|
from malt.converters import break_statements |
|
|
from malt.converters import call_trees |
|
|
from malt.converters import conditional_expressions |
|
|
from malt.converters import continue_statements |
|
|
from malt.converters import control_flow |
|
|
from malt.converters import directives |
|
|
from malt.converters import functions |
|
|
from malt.converters import lists |
|
|
from malt.converters import logical_expressions |
|
|
from malt.converters import return_statements |
|
|
from malt.converters import slices |
|
|
from malt.converters import variables |
|
|
from malt.core import ag_ctx |
|
|
from malt.core import converter |
|
|
from malt.core import unsupported_features_checker |
|
|
from malt.impl import conversion |
|
|
from malt.lang import special_functions |
|
|
from malt.operators import py_builtins |
|
|
from malt.pyct import anno |
|
|
from malt.pyct import cfg |
|
|
from malt.pyct import error_utils |
|
|
from malt.pyct import errors |
|
|
from malt.pyct import inspect_utils |
|
|
from malt.pyct import qual_names |
|
|
from malt.pyct import transpiler |
|
|
from malt.pyct.static_analysis import activity |
|
|
from malt.pyct.static_analysis import reaching_definitions |
|
|
from malt.utils import ag_logging as logging |
|
|
|
|
|
|
|
|
def is_autograph_strict_conversion_mode():
    """Reports whether strict conversion mode is enabled.

    The setting is read from the `AUTOGRAPH_STRICT_CONVERSION` environment
    variable on every call and parsed as an integer.

    Returns:
      bool, True when the variable holds an integer greater than zero; False
      when it is unset or holds zero or a negative integer.

    Raises:
      ValueError: If the variable is set to text that `int()` cannot parse.
    """
    raw_setting = os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')
    return int(raw_setting) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AutoGraphError(errors.PyCTError):
    """Root of the AutoGraph exception hierarchy.

    Adds no behavior to `errors.PyCTError`; it exists so AutoGraph failures
    can be caught as a group.
    """
|
|
|
|
|
|
|
|
class ConversionError(AutoGraphError):
    """Signals a failure while converting an entity.

    Adds no behavior to `AutoGraphError`.
    """
|
|
|
|
|
|
|
|
class StagingError(AutoGraphError):
    """Signals a failure while staging (i.e. executing in Python) converted code.

    Adds no behavior to `AutoGraphError`.
    """
|
|
|
|
|
|
|
|
class _ErrorMetadata(error_utils.ErrorMetadataBase):
    """Error metadata specialized for AutoGraph. See the base class."""

    def create_exception(self, source_error):
        """Builds the exception that should replace `source_error`.

        Args:
          source_error: The exception instance that was originally raised.

        Returns:
          A new exception carrying `self.get_message()`. Its type is chosen
          in this order: the exact type of `source_error` when that is one of
          the PyCT/AutoGraph error classes, otherwise whatever the base class
          produces, otherwise `StagingError`.
        """
        error_type = type(source_error)

        # Exact type match (not isinstance): only these classes are known to
        # be constructible from a single message argument.
        message_only_types = (
            errors.PyCTError,
            AutoGraphError,
            ConversionError,
            StagingError,
        )
        if error_type in message_only_types:
            return error_type(self.get_message())

        from_base = super().create_exception(source_error)
        if from_base is None:
            # The base class could not build a replacement; wrap the message
            # in a StagingError instead.
            return StagingError(self.get_message())
        return from_base
|
|
|
|
|
|
|
|
def _attach_error_metadata(e, f): |
|
|
"""Augments an error with the metadata necessary for rewrite.""" |
|
|
if hasattr(e, 'ag_pass_through'): |
|
|
return |
|
|
|
|
|
metadata = getattr(e, 'ag_error_metadata', None) |
|
|
source_map = f.ag_source_map |
|
|
|
|
|
if metadata is None: |
|
|
logging.log(1, 'Caught error in user callable %s', f, exc_info=True) |
|
|
message = '{}: {}'.format(e.__class__.__name__, e) |
|
|
else: |
|
|
message = None |
|
|
|
|
|
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] |
|
|
|
|
|
e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map, |
|
|
__file__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PyToPy(transpiler.PyToPy):
    """A generic AutoGraph transformer to subclass from or replace.

    Specializes `transpiler.PyToPy` with the AutoGraph static analysis, the
    ordered sequence of converter passes, and the `ag__` namespace that is
    made available to generated code.
    """

    def __init__(self):
        super(PyToPy, self).__init__()
        # Built on first use by get_extra_locals() and reused afterwards.
        self._extra_locals = None

    def get_transformed_name(self, node):
        """Returns the base class's transformed name prefixed with `ag__`."""
        return 'ag__' + super(PyToPy, self).get_transformed_name(node)

    def get_extra_locals(self):
        """Returns the extra symbols exposed to generated code.

        Returns:
          Dict with the single key `'ag__'`, mapped to a synthetic module
          object whose namespace merges this module, a few `converter`
          symbols, `utils`, `special_functions` and `operators`. The dict is
          created once per instance and cached.
        """
        # `import importlib` alone does not guarantee that these submodules
        # are loaded; import them explicitly instead of relying on some other
        # module having done so.
        import importlib.machinery
        import importlib.util

        if self._extra_locals is None:
            module_spec = importlib.machinery.ModuleSpec('malt', None)
            ag_internal = importlib.util.module_from_spec(module_spec)

            # Start from everything defined in the module that holds PyToPy.
            ag_internal.__dict__.update(inspect.getmodule(PyToPy).__dict__)
            ag_internal.ConversionOptions = converter.ConversionOptions
            ag_internal.STD = converter.STANDARD_OPTIONS
            ag_internal.Feature = converter.Feature
            ag_internal.utils = utils

            # Later updates win on name clashes, so `operators` takes
            # precedence over `special_functions`, and both over the above.
            ag_internal.__dict__.update(special_functions.__dict__)
            ag_internal.__dict__.update(operators.__dict__)

            self._extra_locals = {'ag__': ag_internal}
        return self._extra_locals

    def get_caching_key(self, ctx):
        """Returns `ctx.options` as the caching key."""
        return ctx.options

    def initial_analysis(self, node, ctx):
        """Runs the static analysis passes and returns the annotated node.

        Args:
          node: The AST to analyze.
          ctx: The transformation context passed through to the analyses.

        Returns:
          The node after qualified-name, activity and reaching-definitions
          resolution, with `anno.Static.DEFINITIONS` duplicated into
          `anno.Static.ORIG_DEFINITIONS`.
        """
        # The control flow graphs are built from the node as received, before
        # any of the resolve passes run.
        graphs = cfg.build(node)
        node = qual_names.resolve(node)
        node = activity.resolve(node, ctx, None)
        node = reaching_definitions.resolve(node, ctx, graphs)
        anno.dup(
            node,
            {
                anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
            },
        )
        return node

    def transform_ast(self, node, ctx):
        """Applies the AutoGraph converter passes to `node`, in order.

        Args:
          node: The AST to transform.
          ctx: The transformation context; `ctx.user.options` selects the
            optional passes.

        Returns:
          The transformed AST.
        """
        unsupported_features_checker.verify(node)
        node = self.initial_analysis(node, ctx)

        node = functions.transform(node, ctx)
        node = directives.transform(node, ctx)
        node = break_statements.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)

        node = continue_statements.transform(node, ctx)
        node = return_statements.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            node = lists.transform(node, ctx)
            node = slices.transform(node, ctx)
        node = call_trees.transform(node, ctx)
        node = control_flow.transform(node, ctx)
        node = conditional_expressions.transform(node, ctx)
        node = logical_expressions.transform(node, ctx)
        node = variables.transform(node, ctx)
        return node
|
|
|
|
|
|
|
|
def _convert_actual(entity, program_ctx): |
|
|
"""Applies AutoGraph to entity.""" |
|
|
|
|
|
|
|
|
if not hasattr(entity, '__code__'): |
|
|
raise ValueError('Cannot apply autograph to a function that doesn\'t ' |
|
|
'expose a __code__ object. If this is a @tf.function,' |
|
|
' try passing f.python_function instead.') |
|
|
|
|
|
transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx) |
|
|
|
|
|
assert not hasattr(transformed, 'ag_module') |
|
|
assert not hasattr(transformed, 'ag_source_map') |
|
|
transformed.ag_module = module |
|
|
transformed.ag_source_map = source_map |
|
|
return transformed |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def autograph_artifact(entity, extras=None):
    """Tags `entity` with an `autograph_info__` attribute and returns it.

    Args:
      entity: The object to tag. For a bound method the attribute is stored
        on the underlying function (`entity.__func__`).
      extras: The value stored in the `autograph_info__` attribute.

    Returns:
      `entity` itself.
    """
    tag_holder = entity.__func__ if inspect.ismethod(entity) else entity
    setattr(tag_holder, 'autograph_info__', extras)
    return entity
|
|
|
|
|
|
|
|
def is_autograph_artifact(entity):
    """Returns True if `entity` has an `autograph_info__` attribute."""
    try:
        getattr(entity, 'autograph_info__')
    except AttributeError:
        return False
    return True
|
|
|
|
|
|
|
|
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
    """Converts a function call inline.

    For internal use only.

    Note: The argument list is optimized for readability of generated code,
    which may look like this:

      ag__.converted_call(f, (arg1, arg2), None, fscope)
      ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
      ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

    The call is made without conversion when `f` is in the allowlist cache,
    when AutoGraph is disabled in the current context, when `f` is an
    AutoGraph artifact, unsupported or allowlisted, when
    `options.internal_convert_user_code` is off, or when the target has no
    `__code__` or was compiled from `'<string>'`. Partials are unwrapped and
    re-dispatched; builtins are routed through `py_builtins`. If conversion
    fails, the call falls back to the unconverted function unless strict
    conversion mode is on.

    Args:
      f: The function to convert.
      args: Tuple, the original positional arguments of f
      kwargs: Optional[Dict], the original keyword arguments of f
      caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
        scope of the converted function in which this call was originally
        made.
      options: Optional[converter.ConversionOptions], conversion options. If
        not specified, the value of caller_fn_scope.callopts is used. Either
        options or caller_fn_scope must be present.

    Returns:
      Any, the result of executing a possibly-converted `f` with the given
      arguments.

    Raises:
      ValueError: If both `options` and `caller_fn_scope` are None.
    """
    logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
                kwargs)

    if options is None:
        if caller_fn_scope is None:
            raise ValueError(
                'either caller_fn_scope or options must have a value')
        options = caller_fn_scope.callopts

    if conversion.is_in_allowlist_cache(f, options):
        logging.log(2, 'Allowlisted %s: from cache', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
        logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if is_autograph_artifact(f):
        logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
        return _call_unconverted(f, args, kwargs, options)

    # Partials are unwrapped and the wrapped function goes through all the
    # checks again.
    if isinstance(f, functools.partial):
        new_kwargs = {}
        if f.keywords is not None:
            # Copy so that the partial's own keywords are not mutated below.
            new_kwargs = f.keywords.copy()
        if kwargs is not None:
            new_kwargs.update(kwargs)
        new_args = f.args + args
        logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f,
                    new_args, new_kwargs)
        return converted_call(
            f.func,
            new_args,
            new_kwargs,
            caller_fn_scope=caller_fn_scope,
            options=options)

    if inspect_utils.isbuiltin(f):
        # These four builtins are dispatched to dedicated py_builtins helpers
        # that receive the caller's function scope.
        if f is eval:
            return py_builtins.eval_in_original_context(
                f, args, caller_fn_scope)
        if f is super:
            return py_builtins.super_in_original_context(
                f, args, caller_fn_scope)
        if f is globals:
            return py_builtins.globals_in_original_context(caller_fn_scope)
        if f is locals:
            return py_builtins.locals_in_original_context(caller_fn_scope)
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    if conversion.is_unsupported(f):
        return _call_unconverted(f, args, kwargs, options)

    if not options.user_requested and conversion.is_allowlisted(f):
        return _call_unconverted(f, args, kwargs, options)

    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs, options)

    # Bound before the try so the handler below can always log it, even when
    # classifying `f` raises before a branch picks a more specific target.
    target_entity = f
    try:
        if inspect.ismethod(f) or inspect.isfunction(f):
            effective_args = args

            f_self = getattr(f, '__self__', None)
            if f_self is not None:
                # Bound method: pass the instance explicitly as first arg.
                effective_args = (f_self,) + effective_args

        elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
            # Callable object: convert its class's `__call__` and pass the
            # object itself as the first argument.
            target_entity = f.__class__.__call__
            effective_args = (f,) + args

        else:
            raise NotImplementedError('unknown callable type "%s"' % type(f))

    except Exception as e:
        logging.log(1, 'Error transforming entity %s', target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        return _fall_back_unconverted(f, args, kwargs, options, e)

    if not hasattr(target_entity, '__code__'):
        logging.log(2, 'Permanently allowed: %s: native binding',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)
    elif (hasattr(target_entity.__code__, 'co_filename') and
          target_entity.__code__.co_filename == '<string>'):
        logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)

    try:
        program_ctx = converter.ProgramContext(options=options)
        converted_f = _convert_actual(target_entity, program_ctx)
        if logging.has_verbosity(2):
            _log_callargs(converted_f, effective_args, kwargs)
    except Exception as e:
        logging.log(1, 'Error transforming entity %s', target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        return _fall_back_unconverted(f, args, kwargs, options, e)

    try:
        if kwargs is not None:
            result = converted_f(*effective_args, **kwargs)
        else:
            result = converted_f(*effective_args)
    except Exception as e:
        # Errors raised by the converted code get rewrite metadata attached
        # and then propagate unchanged.
        _attach_error_metadata(e, converted_f)
        raise

    return result
|
|
|
|
|
|
|
|
def _call_unconverted(f, args, kwargs, options, update_cache=True): |
|
|
"""Calls the original function without converting with AutoGraph.""" |
|
|
if update_cache: |
|
|
conversion.cache_allowlisted(f, options) |
|
|
|
|
|
|
|
|
|
|
|
if kwargs is not None: |
|
|
return f(*args, **kwargs) |
|
|
return f(*args) |
|
|
|
|
|
|
|
|
def _fall_back_unconverted(f, args, kwargs, options, exc):
    """Calls `f` unconverted after a conversion error, warning if appropriate.

    Args:
      f: The callable that could not be converted.
      args: Positional arguments for `f`.
      kwargs: Optional[Dict], keyword arguments for `f`.
      options: Conversion options, forwarded to the unconverted call.
      exc: The exception that made the conversion fail; its type decides
        whether and how a warning is logged.

    Returns:
      The result of calling `f` without conversion.
    """
    warning_template = (
        'AutoGraph could not transform %s and will run it as-is.\n'
        '%s'
        'Cause: %s\n'
        'To silence this warning, decorate the function with'
        ' @tf.autograph.experimental.do_not_convert')

    if isinstance(exc, errors.InaccessibleSourceCodeError):
        # Warn only where source inspection is supported.
        should_warn = ag_ctx.INSPECT_SOURCE_SUPPORTED
        extra_advice = ''
    elif isinstance(exc, errors.UnsupportedLanguageElementError):
        # Warn only if `f` is not already in the allowlist cache.
        should_warn = not conversion.is_in_allowlist_cache(f, options)
        extra_advice = ''
    else:
        # Any other failure is unexpected: always warn and ask for a report.
        should_warn = True
        extra_advice = (
            'Please report this to the TensorFlow team. When filing the bug, set'
            ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
            ' attach the full output.\n')

    if should_warn:
        logging.warning(warning_template, f, extra_advice, exc)

    return _call_unconverted(f, args, kwargs, options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def internal_convert(f, ctx, convert_by_default=True, user_requested=False):
    """Decorator that applies AutoGraph to a function.

    Use in internal APIs.

    This API is suitable for high order functions internal to a library, and
    more generally any function to which AutoGraph is not applied.

    Guidance: `convert` was a decorator meant for use directly by developers.
    `internal_convert` is to be called from high order functions internal to
    the library. By default, internal library functions are skipped when
    AutoGraph processes the code, which may cause user-supplied functions
    passed to them to be incorrectly skipped as well. `internal_convert`
    helps avoid that. See the following example for more details.

    ```
    =====internal_module.py=====

    def unconverted(input_fn):
      return input_fn()

    def converted(input_fn):
      return internal_convert(input_fn, ctx=control_status_ctx())()

    ======user_module.py======

    @tf.function
    def foo(input_fn):
      return unconverted(input_fn)

    @tf.function
    def bar(input_fn):
      return converted(input_fn)

    @tf.function(autograph=False)
    def baz(input_fn):
      return converted(input_fn)
    ```

    The `foo` function above will execute `input_fn` without autograph
    conversion, while `bar` will run an autographed `input_fn`. `baz` will
    run an unconverted `input_fn`, since `internal_convert` respects the
    control status context.

    Note that both functions in `internal_module` are skipped by autograph
    when tracing. Whether a module/package is skipped is controlled by the
    conversion allowlist configuration (in upstream TensorFlow:
    tensorflow/python/autograph/core/config.py).

    Args:
      f: Callable.
      ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is
        used.
      convert_by_default: bool, whether to use AutoGraph when the context
        doesn't specify.
      user_requested: bool, whether to ignore the conversion allowlist. See
        ConversionOptions.user_requested.

    Returns:
      Either `f` (when it is already an AutoGraph artifact) or a wrapper
      around `f`, tagged as an AutoGraph artifact, that converts or skips
      conversion according to `ctx.status`.

    Raises:
      AssertionError: If `ctx.status` is none of the known statuses.
    """
    if is_autograph_artifact(f):
        return f

    if ctx.status == ag_ctx.Status.ENABLED:
        wrapper_factory = convert(
            recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    elif ctx.status == ag_ctx.Status.DISABLED:
        wrapper_factory = do_not_convert
    elif ctx.status == ag_ctx.Status.UNSPECIFIED:
        if convert_by_default:
            wrapper_factory = convert(
                recursive=True,
                user_requested=user_requested,
                conversion_ctx=ctx)
        else:
            wrapper_factory = call_with_unspecified_conversion_status
    else:
        # Raised explicitly rather than via `assert False`, which is removed
        # under `python -O` and would leave `wrapper_factory` unbound.
        raise AssertionError('This switch contains all possible cases!')
    wrapper = wrapper_factory(f)

    return autograph_artifact(wrapper)
|
|
|
|
|
|
|
|
def call_with_unspecified_conversion_status(func):
    """Wraps `func` so that it runs with the conversion status unspecified.

    Args:
      func: The callable to wrap.

    Returns:
      A wrapper, tagged as an AutoGraph artifact, that calls `func` inside an
      `ag_ctx.ControlStatusCtx` whose status is `Status.UNSPECIFIED`.
    """

    def wrapper(*args, **kwargs):
        unspecified_ctx = ag_ctx.ControlStatusCtx(
            status=ag_ctx.Status.UNSPECIFIED)
        with unspecified_ctx:
            return func(*args, **kwargs)

    # Metadata is copied only from plain functions and methods.
    is_plain_callable = inspect.isfunction(func) or inspect.ismethod(func)
    if is_plain_callable:
        functools.update_wrapper(wrapper, func)

    return autograph_artifact(wrapper)
|
|
|
|
|
|
|
|
def _log_callargs(f, args, kwargs):
    """Logs, at verbosity 2, the defaults of `f` and how its arguments bind.

    Args:
      f: The function about to be called.
      args: Positional arguments for the call.
      kwargs: Optional[Dict], keyword arguments for the call.
    """
    logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
    logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)

    # None stands for "no keyword arguments".
    keyword_args = {} if kwargs is None else kwargs
    bound_args = inspect.getcallargs(f, *args, **keyword_args)

    rows = [
        ' {}: {}'.format(name, value) for name, value in bound_args.items()
    ]
    formatted_callargs = '\n'.join(rows)
    logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def do_not_convert(func=None):
    """Decorator that suppresses the conversion of a function.

    Usable both as `@do_not_convert` and as `@do_not_convert()`.

    Args:
      func: function to decorate.

    Returns:
      If `func` is not None, a wrapper tagged as an AutoGraph artifact that
      calls `func` inside an `ag_ctx.ControlStatusCtx` whose status is
      `Status.DISABLED`.
      If `func` is None, `do_not_convert` itself, so that it can then be
      applied to the function to decorate.
    """
    if func is None:
        return do_not_convert

    def wrapper(*args, **kwargs):
        disabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
        with disabled_ctx:
            return func(*args, **kwargs)

    # Metadata is copied only from plain functions and methods.
    is_plain_callable = inspect.isfunction(func) or inspect.ismethod(func)
    if is_plain_callable:
        functools.update_wrapper(wrapper, func)

    return autograph_artifact(wrapper)
|
|
|
|
|
|
|
|
|
|
|
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
    """Decorator that compiles a function to use AutoGraph operators.

    The decorator is dynamic: the target goes through `converted_call` every
    time the decorated function is called, so the parameter values are known
    at conversion and repeated calls with different types of parameters are
    processed correctly.

    Args:
      recursive: bool, whether to recursively convert any functions or
        classes that the converted function may use.
      optional_features: converter.Feature, allows toggling optional or
        experimental features. When set to None, only the core features are
        enabled.
      user_requested: bool, whether this is a function that the user
        explicitly asked to be converted. See
        ConversionOptions.user_requested.
      conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context
        in which `f` is used. It is entered around every call.

    Returns:
      Callable, a decorator that turns the given function into a wrapper,
      tagged as an AutoGraph artifact, which calls its converted version.
    """

    def decorator(f):
        """Wraps `f` so that each call is routed through `converted_call`."""

        def wrapper(*args, **kwargs):
            """Wrapper that calls the converted version of f."""
            call_options = converter.ConversionOptions(
                recursive=recursive,
                user_requested=user_requested,
                optional_features=optional_features)
            try:
                with conversion_ctx:
                    return converted_call(
                        f, args, kwargs, options=call_options)
            except Exception as e:
                if not hasattr(e, 'ag_error_metadata'):
                    raise
                # Errors annotated during the converted call are replaced by
                # the exception built from their metadata.
                raise e.ag_error_metadata.to_exception(e)

        # Metadata is copied only from plain functions and methods.
        is_plain_callable = inspect.isfunction(f) or inspect.ismethod(f)
        if is_plain_callable:
            functools.update_wrapper(wrapper, f)

        return autograph_artifact(wrapper)

    return decorator
|
|
|
|
|
|
|
|
|
|
|
def to_graph(entity, recursive=True, experimental_optional_features=None):
    """Converts a Python entity into its AutoGraph-transformed equivalent.

    Also see: `malt.to_code`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally,
    `tf.function` uses `to_graph`.

    Example usage:

    >>> def f(x):
    ...   if x > 0:
    ...     y = x * x
    ...   else:
    ...     y = -x
    ...   return y
    ...
    >>> converted_f = to_graph(f)
    >>> x = tf.constant(2)
    >>> converted_f(x)  # converted_f is like a TensorFlow Op.
    <tf.Tensor: shape=(), dtype=int32, numpy=4>

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use
    converted code.

    Methods are converted into unbound function that have an additional first
    argument called `self`.

    NOTE(review): conversion is delegated to `_convert_actual`, which rejects
    any entity that does not expose a `__code__` attribute; confirm whether
    classes are still accepted as stated above.

    For a tutorial, see the
    [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
    For more detailed information, see the
    [reference documentation](https://github.com/pennylaneai/diastatic-malt/blob/main/malt/g3doc/reference/index.md).

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value.

    Returns:
      Same as `entity`, the converted Python function or class, tagged as an
      AutoGraph artifact.

    Raises:
      ConversionError: If the entity could not be converted, i.e. the
        conversion raised ValueError, AttributeError, KeyError, NameError or
        AssertionError. The original exception is attached as the cause.
    """
    try:
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(
                recursive=recursive,
                user_requested=True,
                optional_features=experimental_optional_features))
        return autograph_artifact(_convert_actual(entity, program_ctx))
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as e:
        logging.error(1, 'Error converting %s', entity, exc_info=True)
        # Chain explicitly so the original failure is reported as the direct
        # cause of the ConversionError.
        raise ConversionError('converting {}: {}: {}'.format(
            entity, e.__class__.__name__, str(e))) from e
|
|
|
|
|
|
|
|
def to_code(entity, recursive=True, experimental_optional_features=None):
    """Returns the source code generated by DiastaticMalt, as a string.

    The entity is converted with `to_graph`, then the source of the result is
    fetched with `inspect.getsource` and dedented.

    Example usage:

    >>> def f(x):
    ...   if x < 0:
    ...     x = -x
    ...   return x
    >>> generated_source = malt.to_code(f)

    Also see: `malt.to_graph`.

    Note: If a function has been decorated with `tf.function`, pass its
    underlying Python function, rather than the callable that `tf.function`
    creates:

    >>> @tf.function
    ... def f(x):
    ...   if x < 0:
    ...     x = -x
    ...   return x
    >>> generated_source = malt.to_code(f.python_function)

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `malt.experimental.Feature` value.

    Returns:
      The converted code as string.
    """
    converted_entity = to_graph(
        entity,
        recursive=recursive,
        experimental_optional_features=experimental_optional_features)
    generated_source = inspect.getsource(converted_entity)
    return textwrap.dedent(generated_source)
|
|
|
|
|
|
|
|
# Module-level transpiler instance; `_convert_actual` runs every conversion
# through `_TRANSPILER.transform(...)`.
_TRANSPILER = PyToPy()
|
|
|