Update trjgru.py
Browse files
trjgru.py
CHANGED
|
@@ -285,16 +285,25 @@ import h5py
|
|
| 285 |
|
| 286 |
model = build_combined_model() # Your original model building function
|
| 287 |
# Rebuild the model architecture
|
| 288 |
-
model
|
|
|
|
| 289 |
|
| 290 |
-
#
|
| 291 |
-
|
| 292 |
-
_ = model(dummy_input) # Forward pass to build all layers
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
-
model.load_weights(r"Trj_GRU.weights.h5")
|
| 298 |
|
| 299 |
|
| 300 |
def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
|
|
|
|
| 285 |
|
| 286 |
model = build_combined_model() # Your original model building function
|
| 287 |
# Rebuild the model architecture
|
| 288 |
+
# Step 1: Build the full combined model (with 6 inputs)
|
| 289 |
+
# model = build_combined_model()
|
| 290 |
|
| 291 |
+
# Step 2: Call the model once with dummy data to build the weights
|
| 292 |
+
# import tensorflow as tf
|
|
|
|
| 293 |
|
| 294 |
+
dummy_input = [
|
| 295 |
+
tf.random.normal((1, 8, 95, 95, 2)), # reduced_images_test
|
| 296 |
+
tf.random.normal((1, 95, 95, 8)), # hov_m_test
|
| 297 |
+
tf.random.normal((1, 8, 8, 1)), # test_vmax_3d
|
| 298 |
+
tf.random.normal((1, 8)), # lat_test
|
| 299 |
+
tf.random.normal((1, 8)), # lon_test
|
| 300 |
+
tf.random.normal((1, 9)), # other_scalar_inputs
|
| 301 |
+
]
|
| 302 |
+
_ = model(dummy_input) # Build model by doing one forward pass
|
| 303 |
+
|
| 304 |
+
# Step 3: Load weights
|
| 305 |
+
model.load_weights("Trj_GRU.weights.h5") # Make sure this matches the architecture
|
| 306 |
|
|
|
|
| 307 |
|
| 308 |
|
| 309 |
def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
|