Spaces:
Sleeping
Sleeping
# Diagnostic script: reports the installed TensorFlow/Keras versions, then
# tries to load the saved model with plain ``load_model`` and, if that fails,
# again with a patched ``InputLayer`` (see ``try_load_with_fix``).
import logging
import os

import tensorflow as tf

# Setup logging
logging.basicConfig(level=logging.INFO)
# NOTE(review): ``logger`` is not used by the visible code; output goes
# through print(). Kept in case other code references it.
logger = logging.getLogger("DebugModelLoad")

# Resolved relative to the current working directory, not to this file.
model_path = os.path.join("models", "tcn_hybrid_tuned.h5")

print(f"TF Version: {tf.__version__}")
# getattr with a default keeps the diagnostic running on builds where
# ``tf.keras`` does not expose ``__version__``; the script exists to inspect
# mismatched installs, so it should not die before the load checks run.
print(f"Keras Version: {getattr(tf.keras, '__version__', 'unknown')}")
def try_load_standard():
    """Try ``tf.keras.models.load_model`` on ``model_path`` with no custom objects.

    Returns:
        True if the model loaded, False otherwise. Failures are printed, not
        raised: a ``TypeError`` is reported as the anticipated failure mode,
        any other exception as unexpected.
    """
    print("\n--- Attempting standard load_model ---")
    try:
        # Only whether loading succeeds matters here; the model is discarded.
        tf.keras.models.load_model(model_path)
    except TypeError as e:
        # The script treats TypeError as the known failure that the
        # InputLayer shim in try_load_with_fix is meant to work around.
        print(f"Standard load failed as expected: {e}")
        return False
    except Exception as e:
        # Diagnostic boundary: report anything else instead of crashing.
        print(f"Standard load failed with unexpected error: {e}")
        return False
    print("Standard load success!")
    return True
def try_load_with_fix():
    """Retry loading ``model_path`` with a patched ``InputLayer``.

    The patched layer accepts a ``batch_shape`` constructor argument and
    forwards it to the parent ``InputLayer`` as ``batch_input_shape``. It is
    registered under the name ``"InputLayer"`` in ``custom_objects`` so the
    loader uses it for layers whose config names that class.
    NOTE(review): presumably targets .h5 files whose InputLayer config stores
    ``batch_shape`` while the installed Keras expects ``batch_input_shape``;
    confirm against the traceback printed by ``try_load_standard``.

    Returns:
        True if the model loaded (its summary is then printed), False if
        ``load_model`` raised. Load errors are printed, not raised; an error
        from ``model.summary()`` itself propagates.
    """
    print("\n--- Attempting load_model with FixedInputLayer ---")

    class FixedInputLayer(tf.keras.layers.InputLayer):
        """InputLayer that translates ``batch_shape`` into ``batch_input_shape``."""

        def __init__(self, batch_shape=None, **kwargs):
            # ``batch_shape`` is captured by the named parameter, so it is
            # never forwarded to the parent constructor as an unknown keyword.
            # An explicit ``batch_input_shape`` already in the config wins.
            if batch_shape is not None and "batch_input_shape" not in kwargs:
                kwargs["batch_input_shape"] = batch_shape
            super().__init__(**kwargs)

    try:
        model = tf.keras.models.load_model(
            model_path, custom_objects={"InputLayer": FixedInputLayer}
        )
    except Exception as e:
        # Diagnostic boundary: report the failure instead of crashing.
        print(f"Fix load failed: {e}")
        return False
    # Outside the try so a summary() problem is not misreported as a load
    # failure.
    print("Fix load success!")
    model.summary()
    return True
if __name__ == "__main__":
    # Guard clause: there is nothing to diagnose without the model file.
    # Exit non-zero so shells/CI can tell the run failed.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        raise SystemExit(1)
    # The shim is only tried when the standard load fails (short-circuit
    # ``and``); exit non-zero when neither strategy can load the model.
    if not try_load_standard() and not try_load_with_fix():
        raise SystemExit(1)