File size: 1,628 Bytes
45a9fe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from tensorflow import keras
import numpy as np


# import tensorflow as tf

# loaded = tf.saved_model.load('rwthmaterials_dp800_network1_inclusion')
# print("Available endpoints:", list(loaded.signatures.keys()))

# Restore the trained network from its saved .keras file.
model = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.keras')

# Report the architecture, then list every input and output tensor
# (numbered from 1) with its name and shape.
print("Model Summary:")
model.summary()

print("Inputs:")
for position, tensor in enumerate(model.inputs, start=1):
    print(f"Input {position}: name={tensor.name}, shape={tensor.shape}")

print("Outputs:")
for position, tensor in enumerate(model.outputs, start=1):
    print(f"Output {position}: name={tensor.name}, shape={tensor.shape}")

# Generate a wrapper function that takes one positional argument per model input
def generate_wrapper(model):
    """Return a function that runs ``model.predict`` on plain Python inputs.

    The returned ``wrapper(*args)`` expects one positional argument per entry
    of ``model.inputs``, in the same order. Each argument is converted to a
    NumPy array and reshaped to the declared shape of the matching input
    tensor, with every unspecified (``None``) dimension treated as 1.

    Args:
        model: Object exposing ``inputs`` (each item having a ``shape``) and a
            ``predict`` method, such as a loaded Keras model.

    Returns:
        A callable that returns the prediction as nested Python lists. When
        ``predict`` yields several outputs (a list/tuple or dict of arrays),
        the same container structure is returned with each array converted.

    Raises:
        IndexError: From the wrapper, when fewer arguments are supplied than
            the model has inputs.
        ValueError: From the wrapper (via ``reshape``), when an argument's
            element count does not fit the declared input shape.
    """

    def _to_python(result):
        # predict() gives a single array for one output, but a list (or dict)
        # of arrays for multi-output models; convert each leaf separately.
        if isinstance(result, dict):
            return {key: _to_python(value) for key, value in result.items()}
        if isinstance(result, (list, tuple)):
            return [_to_python(item) for item in result]
        return np.asarray(result).tolist()

    def wrapper(*args):
        input_tensors = model.inputs
        if len(args) < len(input_tensors):
            raise IndexError(
                f"expected {len(input_tensors)} input(s), got {len(args)}"
            )

        processed_inputs = []
        # zip stops at the last model input, so surplus arguments are ignored.
        for value, input_tensor in zip(args, input_tensors):
            # Unknown dimensions (normally just the batch axis) become 1.
            target_shape = [
                1 if dim is None else dim for dim in input_tensor.shape
            ]
            processed_inputs.append(np.asarray(value).reshape(target_shape))

        return _to_python(model.predict(processed_inputs))

    return wrapper

# Create the wrapper: a module-level prediction callable bound to `model`,
# taking one positional argument per model input.
predict_fn = generate_wrapper(model)

# Example usage with dummy data
# Replace with actual input data when integrating with Gradio
# dummy_input1 = np.random.rand(1, 6, 6, 2048)
# dummy_input2 = np.random.rand(1, 6, 6, 2048)
# print(predict_fn(dummy_input1, dummy_input2))