File size: 4,745 Bytes
1ff6b95
 
 
 
 
 
 
 
 
 
4637a55
1ff6b95
 
 
4637a55
1ff6b95
 
 
 
 
 
 
 
 
 
4637a55
 
 
 
1ff6b95
 
 
4637a55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ff6b95
 
4637a55
 
 
 
1ff6b95
 
4637a55
1ff6b95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4637a55
1ff6b95
 
4637a55
 
 
 
 
1ff6b95
 
 
 
4637a55
 
1ff6b95
 
4637a55
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Import all necessary libs
import tensorflow as tf
import keras
import numpy as np
import cv2
import gradio as gr
import os

# Load models once at import time; infer()/infer_v1() reuse these module-level
# instances for every request. Pix2Pix and MS-UNet are loaded before the custom
# Clip layer below is registered.
pix2pix_path = './model/wt_generator_best.keras'
unet_path = './model/unet.keras'

Pix2Pix = keras.saving.load_model(pix2pix_path)  # Pix2Pix generator
MS_UNet = keras.saving.load_model(unet_path)  # MS-UNet

# Custom Keras layer registered before the WD-Net generators below are loaded
# (presumably their saved graphs reference it; TODO confirm).
@keras.saving.register_keras_serializable(package="Clip")
class Clip(keras.layers.Layer):
  """Clamp every element of the input tensor to the closed range [0, 1].

  Registration under package "Clip" lets keras.saving.load_model resolve this
  class when a saved model refers to it.
  """

  def __init__(self, **kwargs):
    # No settings of its own: pass name/dtype/etc. straight to the base Layer.
    super().__init__(**kwargs)

  def call(self, input):
    floor, ceiling = 0, 1
    return tf.clip_by_value(input, clip_value_min=floor, clip_value_max=ceiling)

# Two WD-Net generator checkpoints: WD_Net_old is used by infer() ("Document"
# tab) and WD_Net_new by infer_v1() ("Patch" tab).
old_gen_path = './model/generator_epoch_10.keras'
new_gen_path = './model/WD-Net_generator.keras'

WD_Net_old = keras.saving.load_model(old_gen_path)
WD_Net_new = keras.saving.load_model(new_gen_path)

# Define infer function
def infer(img, model='WD-Net'):
  # Read image
  # img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
  # Image original shape
  org_shape = img.shape
  org_img = tf.image.resize(img, [256, 256], method='area')
  org_img = tf.cast(org_img, tf.uint8).numpy()
  org_img = cv2.resize(org_img, (org_shape[1], org_shape[0]))

  # Choose model
  if model == 'WD-Net':
    generator = WD_Net_old
    img = tf.image.resize(img, [256, 256], method='area')
    # Normalize image and return
    img = tf.cast(img, tf.float32) / 255.
    img = tf.expand_dims(img, axis=0)
    rm_wt = generator.predict(img, verbose=0)
    rm_wt = rm_wt['I'][0]
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = (rm_wt * 255).astype(np.uint8)
  elif model == 'MS-UNet':
    generator = MS_UNet
    img = tf.image.resize(img, [256, 256], method='area')
    # Normalize image and return
    img = (tf.cast(img, tf.float32) - 127.5) / 127.5
    img = tf.expand_dims(img, axis=0)
    rm_wt = generator.predict(img, verbose=0)
    rm_wt = rm_wt[0]
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = ((rm_wt + 1) / 2 * 255).astype(np.uint8)
  elif model == 'Pix2Pix':
    generator = Pix2Pix
    img = tf.image.resize(img, [256, 256], method='area')
    # Normalize image and return
    img = tf.cast(img, tf.float32) / 255.
    img = tf.expand_dims(img, axis=0)
    rm_wt = generator.predict(img, verbose=0)
    rm_wt = rm_wt[0]
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = (rm_wt * 255).astype(np.uint8)
  return org_img, out_img

def infer_v1(img_path, model="WD_Net"):
  # Read image
  img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
  # Image original shape
  org_shape = img.shape
  # org_img = tf.image.resize(img, [256, 256], method='area')
  org_img = tf.cast(img, tf.uint8).numpy()
  # org_img = cv2.resize(org_img, (org_shape[1], org_shape[0]))

  # Choose model
  if model == 'WD-Net':
    generator = WD_Net_new
    img = tf.image.resize(img, [256, 256], method='area')
    # Normalize image and return
    img = tf.cast(img, tf.float32) / 255.
    img = tf.expand_dims(img, axis=0)
    rm_wt = generator.predict(img, verbose=0)
    rm_wt = rm_wt['I'][0]
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = (rm_wt * 255).astype(np.uint8)
  elif model == 'MS-UNet':
    generator = MS_UNet
    img = tf.image.resize(img, [256, 256], method='area')
    # Normalize image and return
    img = (tf.cast(img, tf.float32) - 127.5) / 127.5
    img = tf.expand_dims(img, axis=0)
    rm_wt = generator.predict(img, verbose=0)
    rm_wt = rm_wt[0]
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = ((rm_wt + 1) / 2 * 255).astype(np.uint8)
  elif model == 'Pix2Pix':
    generator = Pix2Pix
    img = tf.image.resize(img, [256, 256], method='area')
    # Normalize image and return
    img = tf.cast(img, tf.float32) / 255.
    img = tf.expand_dims(img, axis=0)
    rm_wt = generator.predict(img, verbose=0)
    rm_wt = rm_wt[0]
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = (rm_wt * 255).astype(np.uint8)
  return org_img, out_img

# Main gradio code
# Image files offered in the "Patch" tab, sorted by name. Hidden entries
# (e.g. .DS_Store, .gitkeep) and non-image files are skipped so they cannot be
# picked and then fail to decode.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
data = sorted(
    name for name in os.listdir('./data')
    if not name.startswith('.') and name.lower().endswith(_IMAGE_EXTENSIONS)
)

# Model list; these names must match the ones infer() / infer_v1() dispatch on.
model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']

# "Document" tab: the user uploads an image.
demo = gr.Interface(
    fn=infer,
    inputs=[gr.Image(label="Choose an Image"),
            gr.Dropdown(model_list, value=model_list[0], label="Model")],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)

# "Patch" tab: the user picks one of the images found in ./data.
demo_v1 = gr.Interface(
    fn=infer_v1,
    inputs=[gr.Dropdown(data, label="Choose an Image"),
            gr.Dropdown(model_list, value=model_list[0], label="Model")],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)

tabbed_interface = gr.TabbedInterface([demo, demo_v1], ["Document", "Patch"], title="Watermark Removal")

# Start the server only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
  tabbed_interface.launch()