| import tensorflow as tf
|
| import matplotlib.pyplot as plt
|
| import numpy as np
|
| import os
|
| from tf_models import Generator
|
|
|
def load_image(image_file):
    """Read a JPEG file and return it as a single-image batch tensor.

    The file is decoded to 3 channels, converted to float32, resized to
    256x256, and rescaled with ``x * 2 - 1``. A leading axis is then
    added so the result can be passed to a model as a batch of one.

    Args:
        image_file: Path to the JPEG file to read.

    Returns:
        The preprocessed image with a leading axis of size 1.
    """
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    resized = tf.image.resize(as_float, [256, 256])
    # Shift the float image so it is centred on zero before inference.
    normalized = resized * 2 - 1
    return tf.expand_dims(normalized, 0)
|
|
|
def predict(model, image_path, output_path="tf_prediction.png"):
    """Run ``model`` on one image and save an input/output comparison figure.

    The image is loaded with ``load_image``, passed through ``model`` in
    inference mode, and both the input and the prediction are drawn side
    by side and written to ``output_path``.

    Args:
        model: Callable accepting ``(batch, training=False)`` and returning
            a batch of images.
        image_path: Path of the image file to run through the model.
        output_path: Where the comparison figure is written. Defaults to
            ``"tf_prediction.png"``, the filename previously hard-coded.
    """
    image = load_image(image_path)
    prediction = model(image, training=False)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    try:
        panels = (("Input Image", image), ("Predicted Image", prediction))
        for ax, (title, batch) in zip(axes, panels):
            ax.set_title(title)
            # Undo the ``x * 2 - 1`` scaling applied by load_image so the
            # values are displayable. NOTE(review): assumes the model output
            # uses the same [-1, 1] range -- confirm against Generator.
            ax.imshow(batch[0] * 0.5 + 0.5)
            ax.axis("off")
        fig.savefig(output_path)
    finally:
        # Always release the figure; pyplot keeps every figure alive until
        # it is explicitly closed, so repeated calls would otherwise leak.
        plt.close(fig)

    print(f"Prediction saved to {output_path}")
|
|
|
def main():
    """Build the generator, load the first usable weights, and run a demo.

    Candidate weight files are tried in order; the first one that exists
    and loads without raising is used. If none can be loaded, the model is
    used as constructed. A prediction is then produced for a fixed test
    image if that file is present.
    """
    model = Generator()

    for weight_path in ("GeneratorHtoZ.h5", "gen_g_epoch_0.h5"):
        if not os.path.exists(weight_path):
            continue
        try:
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
            print(f"Loaded weights from {weight_path}")
            break
        except Exception as e:
            # Best-effort: report the failure and fall through to the next
            # candidate instead of aborting the script.
            print(f"Could not load {weight_path}: {e}")
    else:
        # Reached only when no candidate was loaded (the loop never broke).
        print("Using untrained model.")

    test_image = "data/horse2zebra/testA/n02381460_1010.jpg"
    if not os.path.exists(test_image):
        print(f"Test image {test_image} not found.")
        return

    predict(model, test_image)
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|