# DP800_DamageClassification / check_model.py.saved
# Commit dc3b52a (kerzel): "rename check_model so it is not executed"
from tensorflow import keras
import numpy as np
# import tensorflow as tf
# loaded = tf.saved_model.load('rwthmaterials_dp800_network1_inclusion')
# print("Available endpoints:", list(loaded.signatures.keys()))
# Load the trained Keras model from a path relative to the current working
# directory.
model = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.keras')

# Print the layer summary, then the name and shape of every input and output
# tensor, so it is clear what the prediction wrapper below has to be fed.
print("Model Summary:")
model.summary()
for heading, label, tensors in (
    ("Inputs:", "Input", model.inputs),
    ("Outputs:", "Output", model.outputs),
):
    print(heading)
    for position, tensor in enumerate(tensors, start=1):
        print(f"{label} {position}: name={tensor.name}, shape={tensor.shape}")
# Generate a wrapper function based on input count
def generate_wrapper(model):
def wrapper(*args):
# Convert inputs to numpy arrays and reshape if needed
processed_inputs = []
for i, input_tensor in enumerate(model.inputs):
shape = input_tensor.shape
# Replace None with 1 for batch dimension
input_shape = [dim if dim is not None else 1 for dim in shape]
arr = np.array(args[i]).reshape(input_shape)
processed_inputs.append(arr)
# Predict
prediction = model.predict(processed_inputs)
return prediction.tolist()
return wrapper
# Build the plain-Python prediction callable for the model loaded above.
predict_fn = generate_wrapper(model=model)

# Example call with random placeholder data; replace with real inputs (e.g.
# from a Gradio interface). The two (1, 6, 6, 2048) arrays follow the original
# author's example -- NOTE(review): confirm against the input shapes this
# script prints when it loads the model.
# dummy_input1 = np.random.rand(1, 6, 6, 2048)
# dummy_input2 = np.random.rand(1, 6, 6, 2048)
# print(predict_fn(dummy_input1, dummy_input2))