1plus1's picture
Update app.py
4637a55 verified
# Import all necessary libs
import tensorflow as tf
import keras
import numpy as np
import cv2
import gradio as gr
import os
# Load the Pix2Pix and MS-UNet generators once, at import time.
pix2pix_path = './model/wt_generator_best.keras'
unet_path = './model/unet.keras'

# Pix2Pix generator (loaded first, as before).
Pix2Pix = keras.saving.load_model(pix2pix_path)
# MS-UNet generator.
MS_UNet = keras.saving.load_model(unet_path)
# WD-Net support: custom layer, defined before the WD-Net checkpoints are loaded.
@keras.saving.register_keras_serializable(package="Clip")
class Clip(keras.layers.Layer):
    """Parameter-free layer that clamps every input value into [0, 1].

    Registered with Keras' serialization registry under the "Clip" package so
    that saved models containing this layer can be deserialized.
    """

    def __init__(self, **kwargs):
        # The layer has no configuration of its own; everything (name, dtype,
        # ...) is forwarded to the base Layer.
        super().__init__(**kwargs)

    def call(self, input):
        lower, upper = 0, 1
        return tf.clip_by_value(input, clip_value_min=lower, clip_value_max=upper)
# Two WD-Net generator checkpoints.
# NOTE(review): presumably these serialized models reference the custom
# `Clip` layer, which is why they are loaded only after its registration.
old_gen_path = './model/generator_epoch_10.keras'
new_gen_path = './model/WD-Net_generator.keras'
WD_Net_old = keras.saving.load_model(old_gen_path)
WD_Net_new = keras.saving.load_model(new_gen_path)
# Define infer function
def infer(img, model='WD-Net'):
    """Remove the watermark from an uploaded image.

    Args:
        img: Image array as delivered by ``gr.Image``; ``img.shape[0]`` is
            used as the height and ``img.shape[1]`` as the width. Assumed to
            be an HxWx3 RGB array -- TODO confirm for non-default image modes.
        model: One of ``'WD-Net'``, ``'MS-UNet'`` or ``'Pix2Pix'``. ``None``
            (an unselected dropdown) falls back to the default ``'WD-Net'``.

    Returns:
        ``(org_img, out_img)``. ``org_img`` is the input after the same
        256x256 'area' downscale the model sees, upscaled back to the original
        size (presumably so both outputs are visually comparable).
        ``out_img`` is the model output resized to the original height/width
        as a uint8 array.

    Raises:
        gr.Error: if no image is given or ``model`` is not a known name.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if model is None:
        model = 'WD-Net'

    # name -> (generator, expects inputs in [-1, 1]?, key of the image when
    # predict() returns a dict; None means a plain batch tensor)
    specs = {
        'WD-Net': (WD_Net_old, False, 'I'),
        'MS-UNet': (MS_UNet, True, None),
        'Pix2Pix': (Pix2Pix, False, None),
    }
    if model not in specs:
        raise gr.Error(f"Unknown model: {model!r}")
    generator, signed, out_key = specs[model]

    # cv2.resize takes (width, height).
    out_size = (img.shape[1], img.shape[0])
    # One downscale serves both the preview and the model input.
    small = tf.image.resize(img, [256, 256], method='area')
    org_img = cv2.resize(tf.cast(small, tf.uint8).numpy(), out_size)

    x = tf.cast(small, tf.float32)
    x = (x - 127.5) / 127.5 if signed else x / 255.
    pred = generator.predict(tf.expand_dims(x, axis=0), verbose=0)
    rm_wt = pred[out_key][0] if out_key is not None else pred[0]
    rm_wt = cv2.resize(rm_wt, out_size)
    if signed:
        rm_wt = (rm_wt + 1) / 2
    # Clamp before the uint8 cast: casting out-of-range floats to uint8 is
    # not saturating in NumPy, so slight overshoots would corrupt pixels.
    out_img = (np.clip(rm_wt, 0, 1) * 255).astype(np.uint8)
    return org_img, out_img
def infer_v1(img_path, model='WD-Net'):
    """Remove the watermark from one of the sample images stored in ``./data``.

    Args:
        img_path: Bare file name (no directory part) of an image in ``./data``.
        model: One of ``'WD-Net'``, ``'MS-UNet'`` or ``'Pix2Pix'``. ``None``
            and the legacy spelling ``'WD_Net'`` (the previous default, which
            matched no model) both select ``'WD-Net'``.

    Returns:
        ``(org_img, out_img)``: the decoded image at full resolution and the
        model output resized back to that resolution, as uint8 arrays.

    Raises:
        gr.Error: if no image is chosen, the name is not a plain file inside
            ``./data``, or ``model`` is not a known name.
    """
    if not img_path:
        raise gr.Error("Please choose an image first.")
    # The file name comes from the client; refuse anything with a directory
    # component (e.g. "../x") so reads stay inside ./data.
    if os.path.basename(img_path) != img_path or not os.path.isfile('./data/' + img_path):
        raise gr.Error(f"Invalid image: {img_path!r}")
    if model in (None, 'WD_Net'):
        model = 'WD-Net'

    # name -> (generator, expects inputs in [-1, 1]?, key of the image when
    # predict() returns a dict; None means a plain batch tensor)
    specs = {
        'WD-Net': (WD_Net_new, False, 'I'),
        'MS-UNet': (MS_UNet, True, None),
        'Pix2Pix': (Pix2Pix, False, None),
    }
    if model not in specs:
        raise gr.Error(f"Unknown model: {model!r}")
    generator, signed, out_key = specs[model]

    img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
    # Shown unchanged, at full resolution.
    org_img = tf.cast(img, tf.uint8).numpy()
    # cv2.resize takes (width, height).
    out_size = (img.shape[1], img.shape[0])

    x = tf.cast(tf.image.resize(img, [256, 256], method='area'), tf.float32)
    x = (x - 127.5) / 127.5 if signed else x / 255.
    pred = generator.predict(tf.expand_dims(x, axis=0), verbose=0)
    rm_wt = pred[out_key][0] if out_key is not None else pred[0]
    rm_wt = cv2.resize(rm_wt, out_size)
    if signed:
        rm_wt = (rm_wt + 1) / 2
    # Clamp before the uint8 cast: casting out-of-range floats to uint8 is
    # not saturating in NumPy, so slight overshoots would corrupt pixels.
    out_img = (np.clip(rm_wt, 0, 1) * 255).astype(np.uint8)
    return org_img, out_img
# Main gradio code
# Sample images for the "Patch" tab, sorted for a stable order. Hidden files
# (e.g. .gitkeep) are skipped since they are presumably not sample images.
data = sorted(name for name in os.listdir('./data') if not name.startswith('.'))
# Model list
model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']
demo = gr.Interface(
    fn=infer,
    # Preselect a model so an untouched dropdown never submits None.
    inputs=[
        gr.Image(label="Choose an Image"),
        gr.Dropdown(model_list, value=model_list[0], label="Model"),
    ],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)
demo_v1 = gr.Interface(
    fn=infer_v1,
    inputs=[
        gr.Dropdown(data, value=data[0] if data else None, label="Choose an Image"),
        gr.Dropdown(model_list, value=model_list[0], label="Model"),
    ],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)
tabbed_interface = gr.TabbedInterface([demo, demo_v1], ["Document", "Patch"], title="Watermark Removal")
tabbed_interface.launch()