Spaces:
Sleeping
Sleeping
File size: 2,738 Bytes
b197185 af6288c b197185 af6288c b197185 af6288c b197185 af6288c b197185 af6288c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | import sys
from types import ModuleType
import os
def _register_stub(full_name, parent=None):
    """Create an empty module called *full_name* and register it in ``sys.modules``.

    If *parent* is given, the new module is also bound on it as an attribute
    named after the last dotted component of *full_name* (so that attribute
    access such as ``tf.compat.v1`` works, not just ``import`` statements).
    Returns the new module.
    """
    stub = ModuleType(full_name)
    sys.modules[full_name] = stub
    if parent is not None:
        setattr(parent, full_name.rpartition(".")[2], stub)
    return stub


def patch_tensorflow_with_keras():
    """
    Monkey-patches the 'tensorflow' module to use 'keras' (Keras 3) with a custom backend (like torch).
    This allows libraries like 'fer' to work without needing the full 'tensorflow' package.

    After a successful call:
      * ``tensorflow`` is an empty stub module;
      * ``tensorflow.keras`` and ``tensorflow.python.keras`` are aliases of
        the ``keras`` module itself;
      * ``tensorflow.keras.<sub>`` / ``tensorflow.python.keras.<sub>`` alias
        ``keras.<sub>`` for each common submodule that can be found;
      * ``tensorflow.compat`` and ``tensorflow.compat.v1`` are empty stubs.

    The call is a no-op if a ``tensorflow`` module is already in
    ``sys.modules`` (a real, already-imported TensorFlow or a previous patch).
    Note that only ``sys.modules`` is checked: an installed TensorFlow that has
    not been imported yet WILL be shadowed by the stub.

    Never raises: any failure (e.g. keras or its backend not installed) is
    reported with a printed warning and leaves the interpreter unpatched.
    """
    # Any existing entry means "real tensorflow is loaded" or "already patched".
    # (Checking `not isinstance(..., ModuleType)` instead can never be true --
    # both cases are ModuleType -- which silently clobbered real TensorFlow.)
    if "tensorflow" in sys.modules:
        return
    try:
        import importlib

        # Keras reads KERAS_BACKEND at import time, so this only has an effect
        # if keras has not been imported yet; an explicit user choice wins.
        os.environ.setdefault("KERAS_BACKEND", "torch")
        import keras

        # Top-level stub plus the `tensorflow.python` namespace; both expose
        # `.keras` as an alias of the real keras module.
        tf = _register_stub("tensorflow")
        tf_python = _register_stub("tensorflow.python", tf)
        for prefix, parent in (("tensorflow", tf), ("tensorflow.python", tf_python)):
            sys.modules[f"{prefix}.keras"] = keras
            parent.keras = keras

        # Map common submodules explicitly so that dotted imports such as
        # `from tensorflow.keras.models import ...` resolve via sys.modules.
        sub_modules = (
            "models", "layers", "backend", "utils", "callbacks",
            "initializers", "optimizers", "regularizers", "constraints", "activations",
        )
        for sub in sub_modules:
            module = getattr(keras, sub, None)
            if module is None:
                try:
                    module = importlib.import_module(f"keras.{sub}")
                except (ImportError, AttributeError):
                    # Submodule absent in this keras version: skip it.
                    continue
            sys.modules[f"tensorflow.keras.{sub}"] = module
            sys.modules[f"tensorflow.python.keras.{sub}"] = module
            # tf.keras and tf.python.keras are both `keras`, so one binding
            # makes `tf.keras.<sub>` attribute access work for both.
            setattr(keras, sub, module)

        # Empty compat namespaces some callers probe for.
        tf_compat = _register_stub("tensorflow.compat", tf)
        _register_stub("tensorflow.compat.v1", tf_compat)
        print("Successfully monkey-patched tensorflow with keras.")
    except Exception as e:
        # Deliberate best-effort: the caller may work without the shim.
        print(f"Warning: Failed to monkey-patch tensorflow with keras: {e}")
|