File size: 2,738 Bytes
b197185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af6288c
 
 
 
b197185
af6288c
 
 
b197185
af6288c
 
 
 
 
b197185
af6288c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b197185
 
 
 
af6288c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sys
from types import ModuleType
import os

def patch_tensorflow_with_keras():
    """
    Install a stand-in 'tensorflow' package in sys.modules that is backed by 'keras'.

    The stand-in exposes 'tensorflow.keras' and 'tensorflow.python.keras' (both
    pointing at the imported 'keras' module), a fixed list of common Keras
    submodules under both prefixes, and empty 'tensorflow.compat' /
    'tensorflow.compat.v1' modules. This lets code that only does
    'from tensorflow.keras... import ...' run without the 'tensorflow' package.

    If 'KERAS_BACKEND' is not set in the environment it is defaulted to "torch"
    before 'keras' is imported; an existing value is left unchanged.

    The function is idempotent and never replaces an existing entry: when
    'tensorflow' is already present in sys.modules (the real package, or the
    stand-in from an earlier call) it returns without doing anything.

    Failures are not raised; a warning is printed instead, so callers can treat
    the patch as best-effort.
    """
    if "tensorflow" in sys.modules:
        # Real tensorflow was imported, or a previous call already patched.
        # Both the real package and the stand-in are ModuleType instances, so
        # presence in sys.modules is the only reliable signal here.
        return

    try:
        import importlib

        # Must happen before 'import keras': the backend is read at import time.
        os.environ.setdefault("KERAS_BACKEND", "torch")

        import keras

        # Create a dummy tensorflow module
        tf = ModuleType("tensorflow")
        sys.modules["tensorflow"] = tf

        # Map keras to tensorflow.keras
        sys.modules["tensorflow.keras"] = keras
        tf.keras = keras

        # Create dummy python submodule
        tf_python = ModuleType("tensorflow.python")
        sys.modules["tensorflow.python"] = tf_python
        tf.python = tf_python

        # Map keras to tensorflow.python.keras
        sys.modules["tensorflow.python.keras"] = keras
        tf_python.keras = keras

        # Map common submodules explicitly so that
        # 'import tensorflow.keras.<sub>' resolves straight from sys.modules.
        sub_modules = [
            "models", "layers", "backend", "utils", "callbacks",
            "initializers", "optimizers", "regularizers", "constraints", "activations",
        ]

        for sub in sub_modules:
            try:
                # Prefer the attribute already exposed by keras
                module = getattr(keras, sub, None)
                if module is None:
                    # Fall back to importing the submodule directly
                    module = importlib.import_module(f"keras.{sub}")
            except (ImportError, AttributeError):
                # This keras build has no such submodule; skip it.
                continue

            sys.modules[f"tensorflow.keras.{sub}"] = module
            sys.modules[f"tensorflow.python.keras.{sub}"] = module
            # tf.keras and tf.python.keras are both the keras module itself,
            # so a single setattr covers both access paths.
            setattr(keras, sub, module)

        # Add some dummy compat modules if needed
        tf_compat = ModuleType("tensorflow.compat")
        sys.modules["tensorflow.compat"] = tf_compat
        tf.compat = tf_compat

        tf_v1 = ModuleType("tensorflow.compat.v1")
        sys.modules["tensorflow.compat.v1"] = tf_v1
        tf_compat.v1 = tf_v1

        print("Successfully monkey-patched tensorflow with keras.")
    except Exception as e:
        # Best-effort boundary: report the problem instead of breaking the caller.
        print(f"Warning: Failed to monkey-patch tensorflow with keras: {e}")