# smashfix-v1 / debug_model_load.py
# Last change by uncertainrods, commit 62ca666 ("version_check").
import os
import tensorflow as tf
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DebugModelLoad")
# NOTE(review): `logger` is not used by any code visible in this file; all
# diagnostics below go through print().

# Relative path, so it is resolved against the current working directory.
model_path = os.path.join("models", "tcn_hybrid_tuned.h5")

print(f"TF Version: {tf.__version__}")
# tf.keras.__version__ is not exposed by every TensorFlow/Keras packaging;
# a version-diagnostic script should report "unknown" rather than crash here.
print(f"Keras Version: {getattr(tf.keras, '__version__', 'unknown')}")
def try_load_standard():
    """Attempt to load the model at ``model_path`` with plain ``load_model``.

    No custom objects are supplied, so this shows whether the saved file is
    directly loadable by the installed TensorFlow/Keras.

    Returns:
        True if loading succeeded, False if any ``Exception`` was raised.
        A ``TypeError`` is reported as the anticipated failure; any other
        exception is reported as unexpected.
    """
    print("\n--- Attempting standard load_model ---")
    try:
        tf.keras.models.load_model(model_path)
        print("Standard load success!")
        return True
    except Exception as err:
        # TypeError is the failure this script was written to diagnose.
        if isinstance(err, TypeError):
            print(f"Standard load failed as expected: {err}")
        else:
            print(f"Standard load failed with unexpected error: {err}")
        return False
def try_load_with_fix():
    """Retry loading the model with a patched ``InputLayer``.

    The saved ``InputLayer`` config presumably carries a ``batch_shape``
    keyword that the installed Keras 2.x (e.g. 2.15) ``InputLayer`` does not
    accept; it expects ``batch_input_shape`` instead. The patched layer
    translates one into the other during deserialization.

    Returns:
        True if the model loaded, False if ``load_model`` raised. On success
        the model summary is printed; a failure while printing the summary is
        reported but does not turn a successful load into a failure.
    """
    print("\n--- Attempting load_model with FixedInputLayer ---")

    class FixedInputLayer(tf.keras.layers.InputLayer):
        """``InputLayer`` that also accepts a ``batch_shape`` keyword.

        Because ``batch_shape`` is an explicit parameter it is never passed on
        to the parent constructor; its value is forwarded as
        ``batch_input_shape`` unless that key was already supplied.
        NOTE(review): intended for Keras 2.x; the conversion to
        ``batch_input_shape`` may not suit other Keras versions.
        """

        def __init__(self, batch_shape=None, **kwargs):
            if batch_shape is not None and "batch_input_shape" not in kwargs:
                kwargs["batch_input_shape"] = batch_shape
            super().__init__(**kwargs)

    # Keep the try limited to the load itself so later steps cannot be
    # misreported as a load failure.
    try:
        model = tf.keras.models.load_model(
            model_path, custom_objects={"InputLayer": FixedInputLayer}
        )
    except Exception as err:
        print(f"Fix load failed: {err}")
        return False

    print("Fix load success!")
    try:
        model.summary()
    except Exception as err:
        # The load itself succeeded; only the summary printout failed.
        print(f"Model loaded, but summary() failed: {err}")
    return True
if __name__ == "__main__":
    # Exit with status 1 when the model file is missing or when neither the
    # standard load nor the patched load succeeds, so shell/CI callers can
    # detect the failure instead of always seeing exit status 0.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        raise SystemExit(1)
    # `or` short-circuits: the patched load is only attempted when the
    # standard load fails.
    if not (try_load_standard() or try_load_with_fix()):
        raise SystemExit(1)