"""Run a trained horse->zebra Generator on a single image and save a comparison plot."""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_models import Generator

# Spatial size (height, width) every input image is resized to before inference.
IMG_SIZE = (256, 256)

# Weight files probed in order; the first one that exists and loads is used.
WEIGHT_CANDIDATES = ("GeneratorHtoZ.h5", "gen_g_epoch_0.h5")

# Sample image rendered by main().
TEST_IMAGE = "data/horse2zebra/testA/n02381460_1010.jpg"


def load_image(image_file, size=IMG_SIZE):
    """Load an image file as a batched float32 tensor scaled to [-1, 1].

    Args:
        image_file: Path to the image file on disk.
        size: Target ``(height, width)`` to resize to; defaults to ``IMG_SIZE``.

    Returns:
        A ``float32`` tensor of shape ``(1, height, width, 3)`` with values in
        ``[-1, 1]``.
    """
    raw = tf.io.read_file(image_file)
    # decode_image is the generic decoder; expand_animations=False keeps the
    # result rank-3 (H, W, C) so resize() receives a single frame.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # uint8 -> float32 in [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, list(size))
    # Map [0, 1] -> [-1, 1].
    image = image * 2.0 - 1.0
    # Add the batch dimension the model expects.
    return tf.expand_dims(image, 0)


def predict(model, image_path, output_path="tf_prediction.png"):
    """Run ``model`` on one image and save a side-by-side input/prediction figure.

    Args:
        model: Keras model called as ``model(x, training=False)`` on the output
            of ``load_image``. Its output is assumed to be in [-1, 1] like the
            input (TODO confirm against Generator's final activation).
        image_path: Path to the input image.
        output_path: Where the figure is written (default
            ``"tf_prediction.png"``).

    Returns:
        The model's prediction tensor for the batched input.
    """
    image = load_image(image_path)
    prediction = model(image, training=False)

    fig = plt.figure(figsize=(10, 5))
    try:
        panels = (("Input Image", image), ("Predicted Image", prediction))
        for index, (title, batch) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 2, index)
            ax.set_title(title)
            # [-1, 1] -> [0, 1]. Clip explicitly: generator outputs can stray
            # slightly out of range, which makes imshow warn and clip anyway.
            pixels = np.asarray(batch[0]) * 0.5 + 0.5
            ax.imshow(np.clip(pixels, 0.0, 1.0))
            ax.axis("off")
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"Prediction saved to {output_path}")
    return prediction


def _load_first_available_weights(model, weight_paths):
    """Load the first existing file in ``weight_paths`` into ``model``.

    Returns the path that was loaded, or ``None`` if none existed or loaded.
    """
    for weight_path in weight_paths:
        if not os.path.exists(weight_path):
            continue
        try:
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
        except Exception as e:  # best-effort: report and try the next file
            print(f"Could not load {weight_path}: {e}")
            continue
        print(f"Loaded weights from {weight_path}")
        return weight_path
    return None


def main():
    """Build the generator, load weights if available, and render TEST_IMAGE."""
    model = Generator()
    # Create the model's variables before loading weights. Keras refuses to
    # load HDF5 weights into a subclassed model that has not been called yet.
    # The except in the loader would then silently fall back to an untrained
    # model. This is a no-op for already-built models. Assumes Generator
    # accepts (1, H, W, 3) inputs, as produced by load_image.
    model(tf.zeros((1, *IMG_SIZE, 3)), training=False)

    if _load_first_available_weights(model, WEIGHT_CANDIDATES) is None:
        print("Using untrained model.")

    if os.path.exists(TEST_IMAGE):
        predict(model, TEST_IMAGE)
    else:
        print(f"Test image {TEST_IMAGE} not found.")


if __name__ == "__main__":
    main()