Reyal commited on
Commit
25c4e8b
·
verified ·
1 Parent(s): 7bd40f7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +35 -9
src/streamlit_app.py CHANGED
@@ -57,23 +57,49 @@ GEMMA_PATH = "/app/gemma_tomato_lora"
57
  # ======================
58
  @st.cache_resource
59
  def load_cnn():
60
- from keras.src.engine.input_layer import InputLayer
 
 
61
 
62
- original_from_config = InputLayer.from_config
 
 
 
63
 
64
  @classmethod
65
- def patched_from_config(cls, config):
66
  config.pop("batch_shape", None)
67
  config.pop("optional", None)
68
- return original_from_config(config)
69
 
70
- InputLayer.from_config = patched_from_config
71
 
72
- return tf.keras.models.load_model(
73
- MODEL_PATH,
74
- compile=False
75
- )
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  cnn_model = load_cnn()
78
 
79
  # ======================
 
57
  # ======================
58
  @st.cache_resource
59
  def load_cnn():
60
+ import tensorflow as tf
61
+ from tensorflow.keras.layers import InputLayer, Rescaling
62
+ from tensorflow.keras.utils import custom_object_scope
63
 
64
+ # -------------------------
65
+ # PATCH InputLayer
66
+ # -------------------------
67
+ original_input = InputLayer.from_config
68
 
69
  @classmethod
70
+ def patched_input(cls, config):
71
  config.pop("batch_shape", None)
72
  config.pop("optional", None)
73
+ return original_input(config)
74
 
75
+ InputLayer.from_config = patched_input
76
 
77
+ # -------------------------
78
+ # PATCH Rescaling
79
+ # -------------------------
80
+ original_rescaling = Rescaling.from_config
81
 
82
+ @classmethod
83
+ def patched_rescaling(cls, config):
84
+ if isinstance(config.get("dtype"), dict):
85
+ config["dtype"] = "float32"
86
+ return original_rescaling(config)
87
+
88
+ Rescaling.from_config = patched_rescaling
89
+
90
+ # -------------------------
91
+ # LOAD MODEL
92
+ # -------------------------
93
+ with custom_object_scope({
94
+ "Rescaling": Rescaling,
95
+ "InputLayer": InputLayer
96
+ }):
97
+ model = tf.keras.models.load_model(
98
+ MODEL_PATH,
99
+ compile=False
100
+ )
101
+
102
+ return model
103
  cnn_model = load_cnn()
104
 
105
  # ======================