File size: 1,823 Bytes
d1bfee5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from tf_models import Generator
def load_image(image_file, size=(256, 256)):
    """Read an image file and return it as a one-image batch scaled to [-1, 1].

    Args:
        image_file: Path (string or scalar string tensor) of the image to read.
        size: Target ``(height, width)`` for resizing. Defaults to
            ``(256, 256)``, the size this script has always used.

    Returns:
        A float32 tensor of shape ``(1, height, width, 3)`` with values in
        ``[-1, 1]``.
    """
    raw = tf.io.read_file(image_file)
    # decode_image accepts JPEG/PNG/GIF/BMP (decode_jpeg only targets JPEG).
    # expand_animations=False guarantees a rank-3 (H, W, C) result (first
    # frame for GIFs), which tf.image.resize needs.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # uint8 [0, 255] -> float32 [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, size)
    # [0, 1] -> [-1, 1]; presumably the range the Generator was trained on
    # (TODO confirm against the training pipeline).
    image = image * 2.0 - 1.0
    # Prepend a batch axis so the model receives (1, H, W, 3).
    return tf.expand_dims(image, 0)
def predict(model, image_path, output_path="tf_prediction.png"):
    """Run ``model`` on one image and save an input/prediction comparison figure.

    Args:
        model: Callable taking a batch tensor and a ``training`` keyword
            (e.g. the ``Generator``).
        image_path: Path of the input image; loaded with ``load_image``.
        output_path: File the side-by-side figure is written to. Defaults to
            ``"tf_prediction.png"``, the path this script has always used.

    Returns:
        The raw model output for the one-image batch.
    """
    image = load_image(image_path)
    prediction = model(image, training=False)

    fig = plt.figure(figsize=(10, 5))
    try:
        panels = (("Input Image", image[0]), ("Predicted Image", prediction[0]))
        for position, (title, tensor) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.title(title)
            # Undo the [-1, 1] scaling for display. Clip explicitly so outputs
            # that overshoot (e.g. from an untrained model) don't trigger
            # matplotlib's "Clipping input data" warning.
            plt.imshow(tf.clip_by_value(tensor * 0.5 + 0.5, 0.0, 1.0))
            plt.axis("off")
        fig.savefig(output_path)
    finally:
        # pyplot keeps every open figure alive; release this one even if
        # saving fails so repeated calls don't accumulate figures.
        plt.close(fig)
    print(f"Prediction saved to {output_path}")
    return prediction
def main(test_image="data/horse2zebra/testA/n02381460_1010.jpg",
         weight_paths=("GeneratorHtoZ.h5", "gen_g_epoch_0.h5")):
    """Build the generator, load the first usable weights file, and run a demo.

    Args:
        test_image: Image passed to ``predict``. If it does not exist, a
            message is printed and no prediction is made.
        weight_paths: Candidate ``.h5`` weight files tried in order. The first
            existing file whose load does not raise is used. If none loads,
            the model stays untrained.
    """
    model = Generator()
    # Create the model's variables before loading. tf.keras refuses to load
    # HDF5 weights into a subclassed model that has not been called yet. For
    # an already-built model this is just one extra forward pass. The shape
    # matches what load_image/predict feed the model: (1, 256, 256, 3).
    model(tf.zeros([1, 256, 256, 3]), training=False)

    for weight_path in weight_paths:
        if not os.path.exists(weight_path):
            continue
        try:
            # by_name + skip_mismatch: layers whose name or weight shapes do
            # not match the file are skipped rather than raising, so a
            # "successful" load may still be partial.
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
        except Exception as e:  # best-effort: report and try the next file
            print(f"Could not load {weight_path}: {e}")
        else:
            print(f"Loaded weights from {weight_path}")
            break
    else:
        # Loop finished without a successful load.
        print("Using untrained model.")

    if os.path.exists(test_image):
        predict(model, test_image)
    else:
        print(f"Test image {test_image} not found.")


if __name__ == "__main__":
    main()
|