convert models to keras format
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- check_model.py +49 -0
- rwthmaterials_dp800_network1_inclusion (1).keras +3 -0
- rwthmaterials_dp800_network1_inclusion.keras +3 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.keras filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pyc*
|
check_model.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tensorflow import keras
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# import tensorflow as tf
|
| 6 |
+
|
| 7 |
+
# loaded = tf.saved_model.load('rwthmaterials_dp800_network1_inclusion')
|
| 8 |
+
# print("Available endpoints:", list(loaded.signatures.keys()))
|
| 9 |
+
|
| 10 |
+
# Load the model
|
| 11 |
+
|
| 12 |
+
model = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.keras')
|
| 13 |
+
|
| 14 |
+
# Inspect model inputs and outputs
|
| 15 |
+
print("Model Summary:")
|
| 16 |
+
model.summary()
|
| 17 |
+
|
| 18 |
+
print("Inputs:")
|
| 19 |
+
for i, input_tensor in enumerate(model.inputs):
|
| 20 |
+
print(f"Input {i+1}: name={input_tensor.name}, shape={input_tensor.shape}")
|
| 21 |
+
|
| 22 |
+
print("Outputs:")
|
| 23 |
+
for i, output_tensor in enumerate(model.outputs):
|
| 24 |
+
print(f"Output {i+1}: name={output_tensor.name}, shape={output_tensor.shape}")
|
| 25 |
+
|
| 26 |
+
# Generate a wrapper function based on input count
|
| 27 |
+
def generate_wrapper(model):
|
| 28 |
+
def wrapper(*args):
|
| 29 |
+
# Convert inputs to numpy arrays and reshape if needed
|
| 30 |
+
processed_inputs = []
|
| 31 |
+
for i, input_tensor in enumerate(model.inputs):
|
| 32 |
+
shape = input_tensor.shape
|
| 33 |
+
# Replace None with 1 for batch dimension
|
| 34 |
+
input_shape = [dim if dim is not None else 1 for dim in shape]
|
| 35 |
+
arr = np.array(args[i]).reshape(input_shape)
|
| 36 |
+
processed_inputs.append(arr)
|
| 37 |
+
# Predict
|
| 38 |
+
prediction = model.predict(processed_inputs)
|
| 39 |
+
return prediction.tolist()
|
| 40 |
+
return wrapper
|
| 41 |
+
|
| 42 |
+
# Create the wrapper
|
| 43 |
+
predict_fn = generate_wrapper(model)
|
| 44 |
+
|
| 45 |
+
# Example usage with dummy data
|
| 46 |
+
# Replace with actual input data when integrating with Gradio
|
| 47 |
+
# dummy_input1 = np.random.rand(1, 6, 6, 2048)
|
| 48 |
+
# dummy_input2 = np.random.rand(1, 6, 6, 2048)
|
| 49 |
+
# print(predict_fn(dummy_input1, dummy_input2))
|
rwthmaterials_dp800_network1_inclusion (1).keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a3b68e773015cb1e6784892a05781125cf68ce9faa1286bd2c9b100f8f80f27
|
| 3 |
+
size 88190917
|
rwthmaterials_dp800_network1_inclusion.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a3b68e773015cb1e6784892a05781125cf68ce9faa1286bd2c9b100f8f80f27
|
| 3 |
+
size 88190917
|