File size: 2,177 Bytes
078ce08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
import sys
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import load_model
from PIL import Image
import matplotlib.pyplot as plt
from random import randint
import imageio
from skimage.util.shape import view_as_blocks

'''
Test the model on unseen sample images and plot the input and output images.

Usage: python test.py <test_images.npy> <model_checkpoint.hdf5>
'''

# Fail early with a usage message instead of an opaque IndexError below.
if len(sys.argv) < 3:
    sys.exit("Usage: python test.py <test_images.npy> <model_checkpoint.hdf5>")

# Load test images.
# NOTE(review): presumably a float array of shape (N, H, W, 3) with values in
# [0, 1] (the normalization below assumes that scale) — confirm.
test_images = np.load(sys.argv[1])

# Load model. compile=False because the script only runs inference
# (model.predict), so no optimizer/loss configuration is needed.
model = load_model(sys.argv[2], compile=False)


# Normalize inputs
def normalize_batch(imgs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Perform channel-wise z-score normalization.

    Args:
        imgs: Array whose last axis holds the colour channels; it is
            broadcast against ``mean`` and ``std``. Presumably float RGB
            in [0, 1] (the default statistics are the commonly used
            ImageNet values on that scale) — confirm against the data.
        mean: Per-channel means subtracted from ``imgs``.
        std: Per-channel standard deviations that ``imgs`` is divided by.

    Returns:
        ``(imgs - mean) / std`` as a new float array of the same shape.
    """
    return (imgs - np.asarray(mean)) / np.asarray(std)


# Denormalize outputs
def denormalize_batch(imgs, should_clip=True, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert channel-wise z-score normalization (``x * std + mean``).

    Args:
        imgs: Normalized array whose last axis holds the colour channels;
            it is broadcast against ``mean`` and ``std``.
        should_clip: If true, clip the result to [0, 1]. Model outputs can
            overshoot that range, and clipping keeps them displayable.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations used during normalization.

    Returns:
        The denormalized float array, clipped to [0, 1] when ``should_clip``.
    """
    imgs = imgs * np.asarray(std) + np.asarray(mean)
    if should_clip:
        imgs = np.clip(imgs, 0, 1)
    return imgs


def _to_uint8(imgs):
    """Convert float images in [0, 1] to UINT8 (0-255) for display.

    Values are clipped to [0, 1] first so out-of-range pixels saturate
    instead of wrapping around in the uint8 cast. They are then rounded
    rather than truncated, so e.g. 0.999 maps to 255, not 254. Singleton
    axes are squeezed out, as before.
    """
    return np.round(np.clip(np.squeeze(imgs), 0.0, 1.0) * 255.0).astype(np.uint8)


# Load images as batch (batch size = 4). Each draw is without replacement
# (this needs at least 4 test images), but the two draws are independent, so
# the same image may occasionally be used as both a secret and a cover.
batch_size = 4
secretin = test_images[np.random.choice(len(test_images), size=batch_size, replace=False)]
coverin = test_images[np.random.choice(len(test_images), size=batch_size, replace=False)]

# Perform batch prediction.
# NOTE(review): the model is assumed to take [secret, cover] and return
# (cover_out, secret_out) in that order — confirm against the training code.
coverout, secretout = model.predict([normalize_batch(secretin), normalize_batch(coverin)])

# Postprocess outputs: undo normalization (clipped to [0, 1]), then convert to UINT8
coverout = _to_uint8(denormalize_batch(coverout))
secretout = _to_uint8(denormalize_batch(secretout))

# Convert input images to UINT8 format (0-255)
coverin = _to_uint8(coverin)
secretin = _to_uint8(secretin)


# Plot the images
def plot(im, title, figsize=(20, 20), max_images=4):
    """Draw the first images of a batch side by side in a single row.

    Args:
        im: Batch of images indexed as ``im[i, :, :, :]`` (presumably
            (N, H, W, C) uint8 RGB — confirm).
        title: Prefix for each subplot title; the 1-based image index is
            appended, e.g. "Secret Input 1".
        figsize: Figure size in inches, passed to ``plt.figure``.
        max_images: Maximum number of images to draw. Batches with fewer
            images are drawn in full instead of raising IndexError.

    Returns:
        The created matplotlib Figure. It is not displayed here; the caller
        must call ``plt.show()`` (or save the figure).
    """
    num_images = min(len(im), max_images)
    fig = plt.figure(figsize=figsize)

    for i in range(num_images):
        ax = fig.add_subplot(1, num_images, i + 1)
        ax.set_title(f"{title} {i + 1}")
        ax.imshow(im[i, :, :, :])

    return fig


# Plot secret input and output
plot(secretin, "Secret Input")
plot(secretout, "Secret Output")

# Plot cover input and output
plot(coverin, "Cover Input")
plot(coverout, "Cover Output")

# Display all figures. In a non-interactive run (e.g. the command line below),
# figures are only shown when plt.show() is called; without it the script
# creates them and exits without displaying anything.
plt.show()

# Sample run: python test.py test/testdata.npy checkpoints/steg_model-06-0.03.hdf5