File size: 4,618 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import copy
import importlib
import os
import sys
from keras.src import backend as backend_module
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
def in_tf_graph():
    """Report whether the current code should be treated as TF graph code.

    Returns:
        `True` when the global `"in_tf_graph_scope"` attribute is truthy.
        Otherwise, if TensorFlow has already been imported, `True` exactly
        when TensorFlow is not executing eagerly. `False` in all other
        cases (in particular when TensorFlow was never imported).
    """
    scope_flag = global_state.get_global_attribute("in_tf_graph_scope", False)
    if scope_flag:
        return True

    # Never import TensorFlow just to answer this question: only consult it
    # when some other code has already loaded it.
    if "tensorflow" not in sys.modules:
        return False

    from keras.src.utils.module_utils import tensorflow as tf

    return not tf.executing_eagerly()
def convert_tf_tensor(outputs, dtype=None):
    """Convert `outputs` to a tensor of the active backend when appropriate.

    Args:
        outputs: Value to convert (presumably TensorFlow tensors, judging by
            the function name -- confirm against callers).
        dtype: Optional dtype forwarded to `convert_to_tensor`.

    Returns:
        `outputs` unchanged when the active backend is `"tensorflow"` or when
        `in_tf_graph()` is true; otherwise the result of
        `backend_module.convert_to_tensor(outputs, dtype=dtype)`.
    """
    # `or` short-circuits, so `in_tf_graph()` is only consulted for
    # non-TensorFlow backends.
    if backend_module.backend() == "tensorflow" or in_tf_graph():
        return outputs
    return backend_module.convert_to_tensor(outputs, dtype=dtype)
class TFGraphScope:
    """Context manager that sets the global `"in_tf_graph_scope"` flag.

    Inside the `with` block the global attribute `"in_tf_graph_scope"` is
    `True`. On exit the attribute is restored to the value it had when the
    block was entered, so scopes created from separate instances nest
    correctly.

    Note: a single instance keeps only one saved value, so the same instance
    must not be re-entered while it is already active.
    """

    def __init__(self):
        # Initial snapshot so the attribute always exists; the authoritative
        # snapshot is taken again in `__enter__`.
        self._original_value = global_state.get_global_attribute(
            "in_tf_graph_scope", False
        )

    def __enter__(self):
        # Snapshot at entry rather than at construction: the flag may have
        # changed between creating the scope object and entering it, and
        # restoring a stale value on exit would corrupt an enclosing scope.
        self._original_value = global_state.get_global_attribute(
            "in_tf_graph_scope", False
        )
        global_state.set_global_attribute("in_tf_graph_scope", True)
        # Returning `self` makes `with TFGraphScope() as scope:` usable.
        return self

    def __exit__(self, *args, **kwargs):
        # Always restore, whether or not the block raised; the implicit
        # `None` return lets any exception propagate.
        global_state.set_global_attribute(
            "in_tf_graph_scope", self._original_value
        )
class DynamicBackend:
    """A class that can be used to switch from one backend to another.

    Attribute access on an instance is forwarded to the implementation
    module of the currently selected backend, so the same call site can
    target different backends over time.

    Example:

    ```python
    backend = DynamicBackend("tensorflow")
    y = backend.square(tf.constant(...))
    backend.set_backend("jax")
    y = backend.square(jax.numpy.array(...))
    ```

    Args:
        backend: Initial backend to use (string). When omitted (or falsy),
            the value returned by `keras.src.backend.backend()` is used.
    """

    # Backends whose implementation module is imported on demand from
    # `keras.src.backend.<name>`. "numpy" is handled separately below.
    _IMPORTABLE_BACKENDS = ("tensorflow", "jax", "torch", "openvino")

    def __init__(self, backend=None):
        self._backend = backend or backend_module.backend()

    def set_backend(self, backend):
        """Select the backend that attribute lookups are forwarded to.

        Args:
            backend: One of `"tensorflow"`, `"jax"`, `"torch"`, `"numpy"`
                or `"openvino"`.

        Raises:
            ValueError: If `backend` is not one of the supported names. The
                previously selected backend is left untouched in that case.
        """
        if backend not in ("tensorflow", "jax", "torch", "numpy", "openvino"):
            raise ValueError(
                "Available backends are ('tensorflow', 'jax', 'torch', "
                f"'numpy' and 'openvino'). Received: backend={backend}"
            )
        self._backend = backend

    def reset(self):
        """Switch back to the backend reported by `backend_module.backend()`."""
        self._backend = backend_module.backend()

    @property
    def name(self):
        """Name (string) of the currently selected backend."""
        return self._backend

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the selected backend module.

        Args:
            name: Attribute name that normal lookup did not find.

        Returns:
            The attribute `name` of the selected backend's module.

        Raises:
            AttributeError: If the instance has no `_backend` yet, if the
                selected backend name is not recognised, or if the backend
                module has no attribute `name`.
            NotImplementedError: If `"numpy"` is selected while it is not
                the globally active backend.
        """
        # `__getattr__` only runs when normal lookup fails. If `_backend`
        # itself is missing (an instance produced by `__new__` without
        # `__init__`, as happens while copying or unpickling), reading
        # `self._backend` below would re-enter this method without end.
        if name == "_backend":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )

        selected = self._backend
        if selected in self._IMPORTABLE_BACKENDS:
            module = importlib.import_module(f"keras.src.backend.{selected}")
            return getattr(module, name)
        if selected == "numpy":
            if backend_module.backend() == "numpy":
                return getattr(backend_module, name)
            raise NotImplementedError(
                "Currently, we cannot dynamically import the numpy backend "
                "because it would disrupt the namespace of the import."
            )
        # An unrecognised backend name (only reachable through `__init__`,
        # which does not validate) must not silently resolve to `None`.
        raise AttributeError(
            f"Cannot resolve attribute '{name}': "
            f"unknown backend {selected!r}."
        )
@keras_export("keras.config.set_backend")
def set_backend(backend):
    """Reload the backend (and the Keras package).

    Example:

    ```python
    keras.config.set_backend("jax")
    ```

    WARNING: this function is dangerous and must be used with care.
    Switching the backend does **NOT** convert objects that already exist:
    layers, tensors and any other Keras-originated instances created before
    the call will no longer be usable without errors. It is strongly
    recommended **not** to keep **any** such objects around. This includes
    every function or class instance that uses Keras functionality; all
    such code has to be re-executed after calling `set_backend()`.

    Args:
        backend: Name of the backend to activate. It is written to the
            `KERAS_BACKEND` environment variable before Keras is
            re-imported.
    """
    os.environ["KERAS_BACKEND"] = backend

    # Evict every cached module whose name starts with "keras". The names
    # are snapshotted first because `sys.modules` is mutated in the loop.
    stale_names = [name for name in sys.modules.keys() if name.startswith("keras")]
    for stale_name in stale_names:
        del sys.modules[stale_name]

    # A fresh import now picks up the backend chosen via KERAS_BACKEND.
    import keras

    # Rebind every Keras submodule referenced from this module's globals to
    # its freshly imported counterpart. Iterate over a shallow copy since
    # `globals()` is updated while looping.
    module_type = keras.__class__
    for key, value in copy.copy(globals()).items():
        if value.__class__ != module_type:
            continue
        description = str(value)
        if not description.startswith("<module 'keras."):
            continue
        # The module repr looks like "<module 'keras.x.y' from ...>"; the
        # dotted name is the text between the first pair of single quotes.
        after_quote = description[description.find("'") + 1 :]
        dotted_name = after_quote[: after_quote.find("'")]
        globals()[key] = importlib.import_module(dotted_name)
|