kerzel commited on
Commit
45a9fe7
·
1 Parent(s): f45fd07

convert models to keras format

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.png filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.keras filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc*
check_model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow import keras
2
+ import numpy as np
3
+
4
+
5
+ # import tensorflow as tf
6
+
7
+ # loaded = tf.saved_model.load('rwthmaterials_dp800_network1_inclusion')
8
+ # print("Available endpoints:", list(loaded.signatures.keys()))
9
+
10
+ # Load the model
11
+
12
+ model = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.keras')
13
+
14
+ # Inspect model inputs and outputs
15
+ print("Model Summary:")
16
+ model.summary()
17
+
18
+ print("Inputs:")
19
+ for i, input_tensor in enumerate(model.inputs):
20
+ print(f"Input {i+1}: name={input_tensor.name}, shape={input_tensor.shape}")
21
+
22
+ print("Outputs:")
23
+ for i, output_tensor in enumerate(model.outputs):
24
+ print(f"Output {i+1}: name={output_tensor.name}, shape={output_tensor.shape}")
25
+
26
+ # Generate a wrapper function based on input count
27
+ def generate_wrapper(model):
28
+ def wrapper(*args):
29
+ # Convert inputs to numpy arrays and reshape if needed
30
+ processed_inputs = []
31
+ for i, input_tensor in enumerate(model.inputs):
32
+ shape = input_tensor.shape
33
+ # Replace None with 1 for batch dimension
34
+ input_shape = [dim if dim is not None else 1 for dim in shape]
35
+ arr = np.array(args[i]).reshape(input_shape)
36
+ processed_inputs.append(arr)
37
+ # Predict
38
+ prediction = model.predict(processed_inputs)
39
+ return prediction.tolist()
40
+ return wrapper
41
+
42
+ # Create the wrapper
43
+ predict_fn = generate_wrapper(model)
44
+
45
+ # Example usage with dummy data
46
+ # Replace with actual input data when integrating with Gradio
47
+ # dummy_input1 = np.random.rand(1, 6, 6, 2048)
48
+ # dummy_input2 = np.random.rand(1, 6, 6, 2048)
49
+ # print(predict_fn(dummy_input1, dummy_input2))
rwthmaterials_dp800_network1_inclusion (1).keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a3b68e773015cb1e6784892a05781125cf68ce9faa1286bd2c9b100f8f80f27
3
+ size 88190917
rwthmaterials_dp800_network1_inclusion.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a3b68e773015cb1e6784892a05781125cf68ce9faa1286bd2c9b100f8f80f27
3
+ size 88190917