Spaces:
Runtime error
Runtime error
| # Import all necessary libs | |
| import tensorflow as tf | |
| import keras | |
| import numpy as np | |
| import cv2 | |
| import gradio as gr | |
| import os | |
# Load models
# Saved-model locations for the Pix2Pix generator and the MS-UNet model.
pix2pix_path = './model/wt_generator_best.keras'
unet_path = './model/unet.keras'
# Load them in the original order: Pix2Pix first, then MS-UNet.
Pix2Pix = keras.saving.load_model(pix2pix_path)
MS_UNet = keras.saving.load_model(unet_path)
# WD-Net support: custom layer presumably referenced by the WD-Net checkpoints.
class Clip(keras.layers.Layer):
    """Keras layer that clamps every element of its input into [0, 1]."""

    def __init__(self, **kwargs):
        # No layer-specific options; forward everything to the base Layer.
        super().__init__(**kwargs)

    def call(self, input):
        low, high = 0, 1
        return tf.clip_by_value(input, clip_value_min=low, clip_value_max=high)
# WD-Net generators: WD_Net_old backs infer(), WD_Net_new backs infer_v1().
# The Clip layer class is defined for these checkpoints, but Keras cannot
# resolve an unregistered custom class from a saved config by itself (it
# raises "Could not locate class 'Clip'"). It is therefore supplied
# explicitly. A model that does not use it simply ignores the entry.
old_gen_path = './model/generator_epoch_10.keras'
WD_Net_old = keras.saving.load_model(old_gen_path, custom_objects={'Clip': Clip})
new_gen_path = './model/WD-Net_generator.keras'
WD_Net_new = keras.saving.load_model(new_gen_path, custom_objects={'Clip': Clip})
# Define infer function
def infer(img, model='WD-Net'):
    """Remove the watermark from an uploaded image.

    Args:
        img: The uploaded image. Presumably an H x W x 3 uint8 numpy array
            (gr.Image's default output); TODO confirm. ``None`` (nothing
            uploaded) raises ``gr.Error``.
        model: ``'WD-Net'``, ``'MS-UNet'`` or ``'Pix2Pix'``. ``None`` (an
            untouched dropdown) selects ``'WD-Net'``.

    Returns:
        ``(org_img, out_img)``: two uint8 arrays with the input's height and
        width. ``org_img`` is the input after a round trip through the
        256x256 working resolution; ``out_img`` is the model's output.

    Raises:
        gr.Error: if no image is given or ``model`` is not a known name.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    if model is None:
        model = 'WD-Net'
    # Previously an unknown name fell through every branch and crashed with
    # UnboundLocalError on out_img.
    if model not in ('WD-Net', 'MS-UNet', 'Pix2Pix'):
        raise gr.Error(f"Unknown model: {model!r}")

    org_shape = img.shape
    # cv2.resize expects (width, height).
    out_size = (org_shape[1], org_shape[0])

    # Every model runs at 256x256. Resize once and reuse for both outputs.
    img_256 = tf.image.resize(img, [256, 256], method='area')
    # The displayed input goes through the same down/up-sampling as the
    # prediction, presumably so both panels show comparable detail.
    org_img = cv2.resize(tf.cast(img_256, tf.uint8).numpy(), out_size)

    pixels = tf.cast(img_256, tf.float32)
    if model == 'MS-UNet':
        # MS-UNet works on pixels scaled to [-1, 1].
        batch = tf.expand_dims((pixels - 127.5) / 127.5, axis=0)
        pred = MS_UNet.predict(batch, verbose=0)[0]
        pred = (cv2.resize(pred, out_size) + 1) / 2
    else:
        # WD-Net and Pix2Pix work on pixels scaled to [0, 1].
        batch = tf.expand_dims(pixels / 255., axis=0)
        if model == 'WD-Net':
            # WD-Net's prediction is keyed; 'I' holds the image shown here.
            pred = WD_Net_old.predict(batch, verbose=0)['I'][0]
        else:
            pred = Pix2Pix.predict(batch, verbose=0)[0]
        pred = cv2.resize(pred, out_size)

    # Clip before the cast. Out-of-range floats would otherwise wrap around
    # (or be undefined) when converted to uint8.
    out_img = np.clip(pred * 255, 0, 255).astype(np.uint8)
    return org_img, out_img
def infer_v1(img_path, model="WD_Net"):
    """Remove the watermark from an image file stored under ``./data``.

    Args:
        img_path: File name relative to ``./data``. ``None`` or empty raises
            ``gr.Error``.
        model: ``'WD-Net'``, ``'MS-UNet'`` or ``'Pix2Pix'``. Underscores are
            accepted in place of hyphens, so the default ``"WD_Net"`` selects
            WD-Net. ``None`` (an untouched dropdown) also selects WD-Net.

    Returns:
        ``(org_img, out_img)``: two uint8 arrays. ``org_img`` is the decoded
        image at full resolution; ``out_img`` is the model's output resized
        back to the same height and width.

    Raises:
        gr.Error: if no image is given or ``model`` is not a known name.
    """
    # The default "WD_Net" never matched "WD-Net", so the default call fell
    # through every branch and crashed with UnboundLocalError. Normalise
    # underscores to hyphens instead of changing the signature.
    model = (model or 'WD-Net').replace('_', '-')
    if model not in ('WD-Net', 'MS-UNet', 'Pix2Pix'):
        raise gr.Error(f"Unknown model: {model!r}")
    if not img_path:
        raise gr.Error("Please choose an image.")

    # decode_image sniffs the format (PNG, JPEG, GIF, BMP).
    # expand_animations=False keeps the result 3-D (H, W, 3).
    img = tf.io.decode_image(tf.io.read_file('./data/' + img_path),
                             channels=3, expand_animations=False)
    org_shape = img.shape
    # cv2.resize expects (width, height).
    out_size = (org_shape[1], org_shape[0])
    # Unlike infer(), the original is shown at full resolution.
    org_img = img.numpy()

    pixels = tf.cast(tf.image.resize(img, [256, 256], method='area'), tf.float32)
    if model == 'MS-UNet':
        # MS-UNet works on pixels scaled to [-1, 1].
        batch = tf.expand_dims((pixels - 127.5) / 127.5, axis=0)
        pred = MS_UNet.predict(batch, verbose=0)[0]
        pred = (cv2.resize(pred, out_size) + 1) / 2
    else:
        # WD-Net and Pix2Pix work on pixels scaled to [0, 1].
        batch = tf.expand_dims(pixels / 255., axis=0)
        if model == 'WD-Net':
            # WD-Net's prediction is keyed; 'I' holds the image shown here.
            pred = WD_Net_new.predict(batch, verbose=0)['I'][0]
        else:
            pred = Pix2Pix.predict(batch, verbose=0)[0]
        pred = cv2.resize(pred, out_size)

    # Clip before the cast. Out-of-range floats would otherwise wrap around
    # (or be undefined) when converted to uint8.
    out_img = np.clip(pred * 255, 0, 255).astype(np.uint8)
    return org_img, out_img
# Main gradio code
# Image choices for the "Patch" tab: file names in ./data, sorted.
data = sorted(os.listdir('./data'))
# Model list
model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']
# Dropdowns get an initial value. Without one, submitting an untouched
# dropdown sends None to the handler, which used to crash it.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label="Choose an Image"),
        gr.Dropdown(model_list, value=model_list[0], label="Model"),
    ],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)
demo_v1 = gr.Interface(
    fn=infer_v1,
    inputs=[
        gr.Dropdown(data, value=data[0] if data else None, label="Choose an Image"),
        gr.Dropdown(model_list, value=model_list[0], label="Model"),
    ],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)
tabbed_interface = gr.TabbedInterface([demo, demo_v1], ["Document", "Patch"], title="Watermark Removal")
tabbed_interface.launch()