# Import all necessary libs
import os

import cv2
import gradio as gr
import keras
import numpy as np
import tensorflow as tf

# --- Load models ---

# Pix2Pix generator
pix2pix_path = './model/wt_generator_best.keras'
Pix2Pix = keras.saving.load_model(pix2pix_path)

# MS-UNet
unet_path = './model/unet.keras'
MS_UNet = keras.saving.load_model(unet_path)


# Custom layer required to deserialize the WD-Net generators.
@keras.saving.register_keras_serializable(package="Clip")
class Clip(keras.layers.Layer):
    """Clamp activations into the valid [0, 1] image range."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # 'inputs' (not 'input') to avoid shadowing the builtin.
        return tf.clip_by_value(inputs, 0, 1)


# WD-Net generators: an older checkpoint (used by the "Document" tab)
# and the newest one (used by the "Patch" tab).
old_gen_path = './model/generator_epoch_10.keras'
WD_Net_old = keras.saving.load_model(old_gen_path)

new_gen_path = './model/WD-Net_generator.keras'
WD_Net_new = keras.saving.load_model(new_gen_path)


def _remove_watermark(img, org_shape, model, wd_net_generator):
    """Run the selected generator on *img* and return a uint8 result.

    Args:
        img: HxWx3 image tensor/array (uint8 range), any size.
        org_shape: shape of the original image; output is resized back
            to (org_shape[0], org_shape[1]).
        model: one of 'WD-Net', 'MS-UNet', 'Pix2Pix'.
        wd_net_generator: which WD-Net checkpoint to use for the
            'WD-Net' choice (old vs. new generator).

    Returns:
        uint8 numpy array of shape (org_h, org_w, 3).

    Raises:
        ValueError: if *model* is not a known model name.
    """
    # Every generator was trained on 256x256 inputs.
    resized = tf.image.resize(img, [256, 256], method='area')

    if model == 'WD-Net':
        x = tf.cast(resized, tf.float32) / 255.          # [0, 1] input
        pred = wd_net_generator.predict(tf.expand_dims(x, axis=0), verbose=0)
        # WD-Net returns a dict of outputs; 'I' is the cleaned image.
        out = cv2.resize(pred['I'][0], (org_shape[1], org_shape[0]))
        return (out * 255).astype(np.uint8)

    if model == 'MS-UNet':
        x = (tf.cast(resized, tf.float32) - 127.5) / 127.5   # [-1, 1] input
        pred = MS_UNet.predict(tf.expand_dims(x, axis=0), verbose=0)
        out = cv2.resize(pred[0], (org_shape[1], org_shape[0]))
        # tanh output in [-1, 1] -> [0, 255]
        return ((out + 1) / 2 * 255).astype(np.uint8)

    if model == 'Pix2Pix':
        x = tf.cast(resized, tf.float32) / 255.          # [0, 1] input
        pred = Pix2Pix.predict(tf.expand_dims(x, axis=0), verbose=0)
        out = cv2.resize(pred[0], (org_shape[1], org_shape[0]))
        return (out * 255).astype(np.uint8)

    # Previously an unknown name fell through every branch and crashed
    # with UnboundLocalError; fail loudly and clearly instead.
    raise ValueError(f"Unknown model: {model!r}. Choose one of "
                     "'WD-Net', 'MS-UNet', 'Pix2Pix'.")


def infer(img, model='WD-Net'):
    """Remove a watermark from an uploaded image ("Document" tab).

    Args:
        img: HxWx3 uint8 image from the Gradio image widget.
        model: model name selected in the dropdown.

    Returns:
        (org_img, out_img): the displayed "original" (round-tripped
        through 256x256 so it matches what the generator sees) and the
        de-watermarked image, both at the input's original size.
    """
    org_shape = img.shape
    # Down/up-sample the original so the side-by-side comparison shows
    # the same resolution the model actually processed.
    org_img = tf.image.resize(img, [256, 256], method='area')
    org_img = tf.cast(org_img, tf.uint8).numpy()
    org_img = cv2.resize(org_img, (org_shape[1], org_shape[0]))

    out_img = _remove_watermark(img, org_shape, model, WD_Net_old)
    return org_img, out_img


def infer_v1(img_path, model='WD-Net'):
    """Remove a watermark from a bundled sample image ("Patch" tab).

    Args:
        img_path: file name inside ./data (from the sample dropdown).
        model: model name selected in the dropdown.

    Returns:
        (org_img, out_img): the untouched original and the
        de-watermarked image at the original size.

    Note: the default was previously "WD_Net" (underscore), which
    matched no branch and crashed; fixed to 'WD-Net'.
    """
    img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
    org_shape = img.shape
    # Unlike infer(), show the true original at full quality.
    org_img = tf.cast(img, tf.uint8).numpy()

    out_img = _remove_watermark(img, org_shape, model, WD_Net_new)
    return org_img, out_img


# --- Main gradio code ---

# Sample images for the "Patch" tab, sorted for a stable dropdown order.
data = os.listdir('./data')
data.sort()

# Model list
model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']

demo = gr.Interface(
    fn=infer,
    inputs=[gr.Image(label="Choose an Image"),
            gr.Dropdown(model_list, label="Model")],
    outputs=[gr.Image(label="Watermarked Image"),
             gr.Image(label="Removed Watermarked Image")],
)

demo_v1 = gr.Interface(
    fn=infer_v1,
    inputs=[gr.Dropdown(data, label="Choose an Image"),
            gr.Dropdown(model_list, label="Model")],
    outputs=[gr.Image(label="Watermarked Image"),
             gr.Image(label="Removed Watermarked Image")],
)

tabbed_interface = gr.TabbedInterface(
    [demo, demo_v1], ["Document", "Patch"], title="Watermark Removal"
)

# Guard the launch so importing this module does not start a server.
if __name__ == "__main__":
    tabbed_interface.launch()