"""Load the ``rwthmaterials_dp800_network1_inclusion`` Keras model and print its I/O signature."""
import numpy as np
from tensorflow import keras

# Alternative loader (not used here): if a TensorFlow SavedModel directory
# ``rwthmaterials_dp800_network1_inclusion`` is available, it can be opened
# with ``tf.saved_model.load(...)`` and its endpoints listed via
# ``list(loaded.signatures.keys())``.

MODEL_PATH = 'rwthmaterials_dp800_network1_inclusion.keras'

model = keras.models.load_model(MODEL_PATH)

# Report the architecture, then every input and output tensor (1-based),
# so the caller can see how many arguments the wrapper needs and their shapes.
print("Model Summary:")
model.summary()

print("Inputs:")
for idx, tensor in enumerate(model.inputs, start=1):
    print(f"Input {idx}: name={tensor.name}, shape={tensor.shape}")

print("Outputs:")
for idx, tensor in enumerate(model.outputs, start=1):
    print(f"Output {idx}: name={tensor.name}, shape={tensor.shape}")
| # Generate a wrapper function based on input count | |
def generate_wrapper(model):
    """Build a plain-Python prediction function for ``model``.

    The returned ``wrapper(*args)`` takes exactly one positional argument per
    entry of ``model.inputs`` (same order). It converts each argument to a
    NumPy array shaped for that input, calls ``model.predict`` and returns
    the result as nested Python lists.

    Shape handling per input:
      * A leading ``None`` (batch) dimension is inferred by NumPy (``-1``), so
        an un-batched sample becomes a batch of one and an already-batched
        array keeps its batch size.
      * If any non-batch dimension is ``None``, reshaping would be ambiguous,
        so the array is passed through unchanged. The one exception is that a
        batch axis is prepended when the array has exactly one dimension fewer
        than the input.

    Args:
        model: Object exposing ``inputs`` (each with a ``shape``) and
            ``predict``; in this script, the loaded Keras model.

    Returns:
        A callable ``wrapper(*args)``. For a single output it returns the
        output array's ``tolist()``. For a list/tuple of outputs it returns a
        list of such lists, and for a dict of outputs a dict of them.

    Raises:
        TypeError: raised by ``wrapper`` when the number of arguments differs
            from ``len(model.inputs)``.
        ValueError: raised by ``wrapper`` (from NumPy) when an argument cannot
            be reshaped to its input's fully known shape.
    """

    def _prepare(value, shape):
        """Convert ``value`` to an array compatible with input ``shape``."""
        dims = list(shape)
        if dims and dims[0] is None:
            dims[0] = -1  # let NumPy infer the batch size
        arr = np.asarray(value)
        if any(dim is None for dim in dims[1:]):
            # Unknown non-batch dims: only add a missing batch axis.
            if arr.ndim == len(dims) - 1:
                arr = arr[np.newaxis, ...]
            return arr
        return arr.reshape(dims)

    def _to_list(prediction):
        """Convert single, list/tuple or dict predictions to Python lists."""
        if isinstance(prediction, dict):
            return {key: np.asarray(val).tolist() for key, val in prediction.items()}
        if isinstance(prediction, (list, tuple)):
            return [np.asarray(val).tolist() for val in prediction]
        return np.asarray(prediction).tolist()

    def wrapper(*args):
        inputs = model.inputs
        if len(args) != len(inputs):
            raise TypeError(
                f"wrapper expected {len(inputs)} input(s), got {len(args)}"
            )
        processed = [_prepare(arg, tensor.shape) for arg, tensor in zip(args, inputs)]
        # Single-input models get the bare array; Keras may warn about a
        # structure mismatch when handed a one-element list.
        batch = processed[0] if len(processed) == 1 else processed
        # verbose=0 suppresses the per-call progress bar.
        prediction = model.predict(batch, verbose=0)
        return _to_list(prediction)

    return wrapper
# Build the prediction callable once, bound to the model loaded above.
predict_fn = generate_wrapper(model=model)

# Example usage: pass one array per model input, in ``model.inputs`` order.
# The original example assumed two inputs of shape (1, 6, 6, 2048); confirm
# against the printed input shapes. Random data shaped like each input
# (assumes the non-batch dimensions are fixed):
#
#     dummies = [np.random.rand(1, *tensor.shape[1:]) for tensor in model.inputs]
#     print(predict_fn(*dummies))
#
# When integrating with Gradio, replace the dummy arrays with real input data.