File size: 8,063 Bytes
1a0d68d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
# Copyright 2024 Xanadu Quantum Technologies Inc.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core conversion logic, serves as main point of access."""
import functools
import inspect
import sys
import unittest
from malt.core import config
from malt.pyct import cache
from malt.pyct import inspect_utils
from malt.utils import ag_logging as logging
# Module-level cache of allowlisting decisions. `cache_allowlisted` stores
# `True` under `[entity][options]`, and `is_in_allowlist_cache` queries it with
# `.has(entity, options)`.
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
def _is_of_known_loaded_module(f, module_name):
  """Tests whether `f` is one of the values stored in a loaded module.

  Args:
    f: The object to look for. Matching is by identity, not equality.
    module_name: Name of the module to inspect. Only modules already present
      in `sys.modules` are considered; nothing is imported.

  Returns:
    True if the module is loaded and `f` is a non-None value in its namespace,
    False otherwise.
  """
  loaded = sys.modules.get(module_name)
  if loaded is None:
    return False
  for member in loaded.__dict__.values():
    # A None entry never counts as a match, even when `f` itself is None.
    if member is f and member is not None:
      return True
  return False
def _is_known_loaded_type(f, module_name, entity_name):
  """Tests whether the function or method is an instance of a known type.

  Args:
    f: The function, method or other object to test.
    module_name: Name of the module expected to define the type. The module
      is only consulted if it is already in `sys.modules`.
    entity_name: Name of the type inside that module.

  Returns:
    True if `f`, or the function underlying the bound method `f`, is an
    instance of the named type. False otherwise, including when the module is
    not loaded or does not define `entity_name`.
  """
  try:
    module = sys.modules[module_name]
  except KeyError:
    return False
  if not hasattr(module, entity_name):
    return False
  known_type = getattr(module, entity_name)

  # The object itself is of this type. Example:
  #
  # o = ClassType()
  # function(o.method)()
  if isinstance(f, known_type):
    return True

  # The function wrapped by the bound method is of this type. Example:
  #
  # class ClassType:
  #   @function
  #   def method(self):
  #     ...
  # o = ClassType()
  # o.method()
  #
  # Note: inspect is required here, to avoid unpacking tf.function decorators.
  return inspect.ismethod(f) and isinstance(f.__func__, known_type)
def is_unsupported(o):
  """Checks whether an entity is supported by AutoGraph at all.

  Entities reported here are permanently allowed, i.e. run as-is. The checks
  are applied in order: wrapt-decorated callables, `functools.lru_cache`
  wrappers, constructors, members of a fixed set of standard-library modules,
  and entities defined in a module marked with `_IS_TENSORFLOW_PLUGIN`.

  Args:
    o: A Python entity.

  Returns:
    True if `o` matches one of the checks above, False otherwise.
  """
  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warning(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library.
  if hasattr(o, '__module__'):
    owner_module = o.__module__
    # `__module__` normally holds the module's *name*, not the module object,
    # so the marker has to be looked up on the loaded module. A non-string
    # value (or a name that is not loaded) is checked directly, as before.
    if isinstance(owner_module, str):
      owner_module = sys.modules.get(owner_module, owner_module)
    if hasattr(owner_module, '_IS_TENSORFLOW_PLUGIN'):
      logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
      return True

  return False
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Examples of allowed entities include all members of the tensorflow
  package.

  The decision is taken by the first matching step, in this order:
    1. The first rule in `config.CONVERSION_RULES` that returns `CONVERT`
       (not allowed) or `DO_NOT_CONVERT` (allowed) for the entity's module.
    2. Generator functions are allowed.
    3. Callable objects are allowed if their `__call__` is allowed.
    4. Methods are allowed if their owner is a `unittest.TestCase` subclass or
       if their defining class is allowed.
    5. Named tuples are allowed (see `allow_namedtuple_subclass`).
    6. Anything else is not allowed.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which classes are allowed if their
      __call__ method is allowed.
    allow_namedtuple_subclass: Reserved for internal use. When `True`, only
      namedtuple types with no namedtuple base are allowed by the namedtuple
      rule; namedtuple subclasses are not.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # functools.partial objects do not have a __module__ attribute, so
  # getmodule would otherwise return None for them.
  # (dime10) inspect.getmodule is the replacement for tf_inspect.getmodule
  module = (
      functools if isinstance(o, functools.partial) else inspect.getmodule(o))

  # Examples of callables that lack a __module__ property include builtins.
  if hasattr(module, '__name__'):
    for rule in config.CONVERSION_RULES:
      verdict = rule.get_action(module)
      if verdict == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, rule)
        return False
      if verdict == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, rule)
        return True

  # The check for __code__ below is because isgeneratorfunction crashes
  # without one.
  # (dime10) replacement for tf_inspect.isgeneratorfunction
  if hasattr(o, '__code__') and inspect.isgeneratorfunction(o):
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  # Callable objects: allowed if their __call__ method is.
  # The type check avoids infinite recursion around the __call__ method
  # of function objects.
  # (dime10) inspect.isclass is the replacement for tf_inspect.isclass
  if (check_call_override and not inspect.isclass(o) and
      hasattr(o, '__call__') and
      type(o) != type(o.__call__) and  # pylint: disable=unidiomatic-typecheck
      is_allowlisted(o.__call__)):
    logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
    return True

  # (dime10) replacement for tf_inspect.ismethod
  if inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    # class Custom(tf.Foo):
    #   pass
    #
    # baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.
    owner_class = inspect_utils.getmethodclass(o)
    # (dime10) strip TfMethodTarget case
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      defining_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_allowlisted(
          defining_class,
          check_call_override=False,
          allow_namedtuple_subclass=True):
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    defining_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if not allow_namedtuple_subclass:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True
    if not any(inspect_utils.isnamedtuple(parent) for parent in o.__bases__):
      logging.log(2, 'Allowlisted: %s: named tuple', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False
def is_in_allowlist_cache(entity, options):
  """Reports whether `entity` was cached as allowlisted under `options`.

  Args:
    entity: The Python entity to look up.
    options: The options the cached decision is keyed by.

  Returns:
    The result of `_ALLOWLIST_CACHE.has(entity, options)`, or False when the
    lookup raises `TypeError`.
  """
  try:
    cached = _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    return False
  return cached
def cache_allowlisted(entity, options):
  """Records `entity` as allowlisted under `options`, on a best-effort basis.

  Args:
    entity: The Python entity to record.
    options: The options the decision is keyed by.

  Returns:
    None. A `TypeError` raised while writing to the cache is ignored, in which
    case nothing is recorded.
  """
  try:
    per_entity = _ALLOWLIST_CACHE[entity]
    per_entity[options] = True
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    return
|