File size: 1,823 Bytes
d1bfee5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from tf_models import Generator

def load_image(image_file, size=(256, 256)):
    """Read an image file and turn it into a single-item model input batch.

    Args:
        image_file: Path (string or scalar string tensor) of the image to read.
        size: Target ``(height, width)`` passed to ``tf.image.resize``.
            Defaults to ``(256, 256)``, the size previously hard-coded here.

    Returns:
        A float32 tensor of shape ``(1, height, width, 3)`` with pixel values
        mapped to ``[-1, 1]``.
    """
    raw = tf.io.read_file(image_file)
    # decode_image detects the format (JPEG, PNG, BMP, GIF) instead of
    # accepting JPEG only; expand_animations=False guarantees a 3-D
    # (H, W, C) tensor even for GIFs. For JPEG input this matches decode_jpeg.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # uint8 -> float32 in [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, list(size))
    # Map [0, 1] -> [-1, 1].
    image = image * 2.0 - 1.0
    # Prepend a batch axis so the result can be fed straight to the model.
    return tf.expand_dims(image, 0)

def predict(model, image_path, output_path="tf_prediction.png"):
    """Run ``model`` on one image and save a side-by-side input/prediction plot.

    Args:
        model: Callable taking a ``(1, H, W, 3)`` batch and a ``training``
            keyword, as used below. Presumably the ``Generator`` from
            ``tf_models`` -- confirm.
        image_path: Path of the input image, loaded via ``load_image``.
        output_path: File the comparison figure is written to. Defaults to
            the previously hard-coded ``"tf_prediction.png"``.
    """
    image = load_image(image_path)
    prediction = model(image, training=False)

    # The input is in [-1, 1] (see load_image). The prediction is assumed to
    # be too (TODO confirm the generator's output activation). Map both back
    # to [0, 1] for display. Clip explicitly so out-of-range values (e.g.
    # from an untrained model) don't trigger imshow's clipping warning.
    panels = (
        ("Input Image", image[0]),
        ("Predicted Image", prediction[0]),
    )

    fig = plt.figure(figsize=(10, 5))
    try:
        for index, (title, tensor) in enumerate(panels, start=1):
            plt.subplot(1, 2, index)
            plt.title(title)
            plt.imshow(np.clip(np.asarray(tensor) * 0.5 + 0.5, 0.0, 1.0))
            plt.axis("off")
        fig.savefig(output_path)
    finally:
        # pyplot keeps every open figure alive; release this one even on error.
        plt.close(fig)
    print(f"Prediction saved to {output_path}")

def main(
    weight_paths=("GeneratorHtoZ.h5", "gen_g_epoch_0.h5"),
    test_image="data/horse2zebra/testA/n02381460_1010.jpg",
):
    """Build the generator, try to load saved weights, and run one demo prediction.

    Args:
        weight_paths: Candidate HDF5 weight files, tried in order. The first
            existing file whose ``load_weights`` call succeeds is used. If
            none does, the freshly initialised model is used instead.
        test_image: Image passed to ``predict``. If the file does not exist,
            a message is printed and no prediction is made.
    """
    model = Generator()

    # Keras refuses to load HDF5 weights into a subclassed model until its
    # variables exist (it raises ValueError). Build the model once with a
    # dummy batch shaped like load_image's default output. Models that are
    # already built (e.g. functional models) skip this.
    if not model.built:
        model(tf.zeros((1, 256, 256, 3)), training=False)

    for weight_path in weight_paths:
        if not os.path.exists(weight_path):
            continue
        try:
            # NOTE(review): by_name + skip_mismatch can "succeed" while matching
            # few or no layers; confirm saved layer names match Generator's.
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
        except Exception as e:
            # Best-effort: report the failure and try the next candidate.
            print(f"Could not load {weight_path}: {e}")
            continue
        print(f"Loaded weights from {weight_path}")
        break
    else:
        # Reached only when no candidate loaded successfully.
        print("Using untrained model.")

    if os.path.exists(test_image):
        predict(model, test_image)
    else:
        print(f"Test image {test_image} not found.")

if __name__ == "__main__":
    # Run the inference demo only when executed as a script, not on import.
    main()