|
|
import copy |
|
|
import importlib |
|
|
import os |
|
|
import sys |
|
|
|
|
|
from keras.src import backend as backend_module |
|
|
from keras.src.api_export import keras_export |
|
|
from keras.src.backend.common import global_state |
|
|
|
|
|
|
|
|
def in_tf_graph():
    """Report whether code is currently considered to run in a TF graph.

    Returns:
        `True` if the global attribute `"in_tf_graph_scope"` is truthy.
        Otherwise, if TensorFlow has already been imported, `True` exactly
        when TensorFlow is not executing eagerly. `False` in all other
        cases.
    """
    scope_active = global_state.get_global_attribute("in_tf_graph_scope", False)
    if scope_active:
        return True

    # TensorFlow is only consulted when it is already present in
    # `sys.modules`; this function never triggers a first-time import.
    if "tensorflow" not in sys.modules:
        return False

    from keras.src.utils.module_utils import tensorflow as tf

    return not tf.executing_eagerly()
|
|
|
|
|
|
|
|
def convert_tf_tensor(outputs, dtype=None):
    """Convert `outputs` to a tensor of the active backend when needed.

    Args:
        outputs: Value to convert.
        dtype: Optional dtype forwarded to `convert_to_tensor`.

    Returns:
        `outputs` unchanged when the active backend is `"tensorflow"` or when
        `in_tf_graph()` is true; otherwise the result of
        `backend_module.convert_to_tensor(outputs, dtype=dtype)`.
    """
    runs_on_tensorflow = backend_module.backend() == "tensorflow"
    # `in_tf_graph()` is only evaluated for non-TensorFlow backends.
    if runs_on_tensorflow or in_tf_graph():
        return outputs
    return backend_module.convert_to_tensor(outputs, dtype=dtype)
|
|
|
|
|
|
|
|
class TFGraphScope:
    """Context manager that sets the `"in_tf_graph_scope"` global attribute.

    The attribute is set to `True` on entry. On exit it is set back to the
    value that was read when this object was constructed.
    """

    # Name of the global attribute toggled by this scope.
    _SCOPE_ATTRIBUTE = "in_tf_graph_scope"

    def __init__(self):
        # Captured at construction time (not at `__enter__`).
        previous = global_state.get_global_attribute(
            self._SCOPE_ATTRIBUTE, False
        )
        self._original_value = previous

    def __enter__(self):
        """Mark the scope as active."""
        global_state.set_global_attribute(self._SCOPE_ATTRIBUTE, True)

    def __exit__(self, *args, **kwargs):
        """Restore the value captured in `__init__`."""
        global_state.set_global_attribute(
            self._SCOPE_ATTRIBUTE, self._original_value
        )
|
|
|
|
|
|
|
|
class DynamicBackend:
    """A class that can be used to switch from one backend to another.

    Attribute access on an instance is forwarded to the implementation
    module of the currently selected backend.

    Example:

    ```python
    backend = DynamicBackend("tensorflow")
    y = backend.square(tf.constant(...))
    backend.set_backend("jax")
    y = backend.square(jax.numpy.array(...))
    ```

    Args:
        backend: Initial backend to use (string). When falsy, the value of
            `backend_module.backend()` is used instead.
    """

    # Backends whose implementation module is imported on demand from
    # `keras.src.backend.<name>`. "numpy" is handled separately.
    _IMPORTABLE_BACKENDS = ("tensorflow", "jax", "torch", "openvino")

    def __init__(self, backend=None):
        self._backend = backend or backend_module.backend()

    def set_backend(self, backend):
        """Select the backend that attribute lookups are forwarded to.

        Args:
            backend: One of `"tensorflow"`, `"jax"`, `"torch"`, `"numpy"`
                or `"openvino"`.

        Raises:
            ValueError: If `backend` is not one of the names above.
        """
        if backend not in ("tensorflow", "jax", "torch", "numpy", "openvino"):
            raise ValueError(
                "Available backends are ('tensorflow', 'jax', 'torch', "
                f"'numpy' and 'openvino'). Received: backend={backend}"
            )
        self._backend = backend

    def reset(self):
        """Switch back to the value of `backend_module.backend()`."""
        self._backend = backend_module.backend()

    @property
    def name(self):
        """Name of the currently selected backend."""
        return self._backend

    def __getattr__(self, name):
        """Resolve `name` on the selected backend's implementation module.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If the instance has no `_backend` set yet, if
                the selected backend name is not recognised, or if the
                backend module has no attribute `name`.
            NotImplementedError: If `"numpy"` is selected but is not the
                globally active backend.
        """
        # Read `_backend` straight from the instance dict. Going through
        # `self._backend` would re-enter `__getattr__` and recurse without
        # bound when the attribute is missing, e.g. on an instance created
        # via `__new__` while `copy`/`pickle` probe it for `__setstate__`.
        try:
            backend = self.__dict__["_backend"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

        if backend == "numpy":
            if backend_module.backend() == "numpy":
                return getattr(backend_module, name)
            raise NotImplementedError(
                "Currently, we cannot dynamically import the numpy backend "
                "because it would disrupt the namespace of the import."
            )

        if backend in self._IMPORTABLE_BACKENDS:
            module = importlib.import_module(f"keras.src.backend.{backend}")
            return getattr(module, name)

        # Fail loudly instead of implicitly returning `None`, which would
        # only surface later as an unrelated "NoneType is not callable".
        raise AttributeError(
            f"Cannot resolve attribute '{name}': unknown backend "
            f"{backend!r}."
        )
|
|
|
|
|
|
|
|
@keras_export("keras.config.set_backend")
def set_backend(backend):
    """Reload the backend (and the Keras package).

    Example:

    ```python
    keras.config.set_backend("jax")
    ```

    ⚠️ WARNING ⚠️: Using this function is dangerous and should be done
    carefully. Changing the backend will **NOT** convert
    the type of any already-instantiated objects.
    Thus, any layers / tensors / etc. already created will no
    longer be usable without errors. It is strongly recommended **not**
    to keep around **any** Keras-originated objects instances created
    before calling `set_backend()`.

    This includes any function or class instance that uses any Keras
    functionality. All such code needs to be re-executed after calling
    `set_backend()`.

    Args:
        backend: Backend name, stored in the `KERAS_BACKEND` environment
            variable before Keras is re-imported.
    """
    os.environ["KERAS_BACKEND"] = backend

    # Drop every cached module whose name starts with "keras". Keys are
    # snapshotted first so `sys.modules` is not mutated while iterating.
    for cached_name in list(sys.modules.keys()):
        if cached_name.startswith("keras"):
            del sys.modules[cached_name]

    # Fresh import now that the cached modules are gone.
    import keras

    # Rebind module-level names of this file that still point at the old
    # `keras.*` module objects to their freshly imported counterparts.
    for key, value in globals().copy().items():
        if not (value.__class__ == keras.__class__):
            continue
        description = str(value)
        if not description.startswith("<module 'keras."):
            continue
        # The dotted module name is the text between the first pair of
        # single quotes in the module's string representation.
        after_quote = description[description.find("'") + 1 :]
        module_name = after_quote[: after_quote.find("'")]
        globals()[key] = importlib.import_module(module_name)
|
|
|