# NOTE: Hugging Face Spaces file-viewer snapshot (status: Runtime error;
# file size: 4,745 bytes; commits 1ff6b95 / 4637a55). The viewer header and
# line-number ruler were artifacts of the page capture, not app source.
# Import all necessary libs
import tensorflow as tf
import keras
import numpy as np
import cv2
import gradio as gr
import os
# Load model
# All models are loaded once at import time so each Gradio request only
# pays for inference, not deserialization.
# Load Pix2Pix
pix2pix_path = './model/wt_generator_best.keras'
Pix2Pix = keras.saving.load_model(pix2pix_path)
# Load MS-UNet
unet_path = './model/unet.keras'
MS_UNet = keras.saving.load_model(unet_path)
# Load WD-Net
@keras.saving.register_keras_serializable(package="Clip")
class Clip(keras.layers.Layer):
    """Stateless layer that clamps activations to the [0, 1] range.

    Registered with the Keras serialization registry so saved models that
    embed this layer (the WD-Net generators loaded below) can be restored
    with `keras.saving.load_model`.
    """

    def call(self, inputs):
        # Element-wise clamp; the layer has no weights or configuration.
        # (Renamed from `input`, which shadowed the builtin; the redundant
        # __init__ that only forwarded **kwargs to super() was dropped —
        # keras.layers.Layer.__init__ already handles that.)
        return tf.clip_by_value(inputs, 0, 1)
# WD-Net generator checkpoints. Loading depends on the Clip layer
# registered above being importable at deserialization time.
# Epoch-10 snapshot: served by `infer` (the "Document" tab).
old_gen_path = './model/generator_epoch_10.keras'
WD_Net_old = keras.saving.load_model(old_gen_path)
# Final generator: served by `infer_v1` (the "Patch" tab).
new_gen_path = './model/WD-Net_generator.keras'
WD_Net_new = keras.saving.load_model(new_gen_path)
# Define infer function
def infer(img, model='WD-Net'):
    """Remove the watermark from an uploaded image with the selected model.

    Args:
        img: Input image as an H x W x 3 array (from ``gr.Image``;
            presumably uint8 RGB — confirm against the Gradio wiring).
        model: One of ``'WD-Net'``, ``'MS-UNet'`` or ``'Pix2Pix'``.

    Returns:
        Tuple ``(org_img, out_img)`` of uint8 arrays at the input's original
        resolution: a degraded preview of the input and the de-watermarked
        prediction.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    org_shape = img.shape
    # Preview image: round-trip through 256x256 so the "before" panel shows
    # the same detail loss the model input suffers.
    org_img = tf.image.resize(img, [256, 256], method='area')
    org_img = tf.cast(org_img, tf.uint8).numpy()
    org_img = cv2.resize(org_img, (org_shape[1], org_shape[0]))

    # Every model consumes a single 256x256 image; the branches differ only
    # in normalization and in how the prediction maps back to [0, 1].
    small = tf.image.resize(img, [256, 256], method='area')
    if model == 'WD-Net':
        batch = tf.expand_dims(tf.cast(small, tf.float32) / 255., axis=0)
        # WD-Net returns a dict; key 'I' holds the restored image in [0, 1].
        rm_wt = WD_Net_old.predict(batch, verbose=0)['I'][0]
    elif model == 'MS-UNet':
        # MS-UNet works in [-1, 1] on both input and output.
        batch = tf.expand_dims((tf.cast(small, tf.float32) - 127.5) / 127.5, axis=0)
        rm_wt = (MS_UNet.predict(batch, verbose=0)[0] + 1) / 2
    elif model == 'Pix2Pix':
        batch = tf.expand_dims(tf.cast(small, tf.float32) / 255., axis=0)
        rm_wt = Pix2Pix.predict(batch, verbose=0)[0]
    else:
        # Previously an unrecognized name fell through every branch and
        # crashed with UnboundLocalError on `out_img`; fail loudly instead.
        raise ValueError(
            f"Unknown model {model!r}; expected 'WD-Net', 'MS-UNet' or 'Pix2Pix'.")
    # Upscale the [0, 1] prediction to the original size and quantize.
    # (Resizing commutes with the affine rescale, so folding the
    # denormalization into the branches preserves the original output.)
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = (rm_wt * 255).astype(np.uint8)
    return org_img, out_img
def infer_v1(img_path, model='WD-Net'):
    """Remove the watermark from a bundled sample image.

    Like ``infer`` but reads the image from ``./data/<img_path>`` and keeps
    the "before" panel at full fidelity, and WD-Net uses the newer
    generator checkpoint.

    Args:
        img_path: Filename of a PNG inside ``./data`` (from the dropdown).
        model: One of ``'WD-Net'``, ``'MS-UNet'`` or ``'Pix2Pix'``.

    Returns:
        Tuple ``(org_img, out_img)`` of uint8 arrays at the image's original
        resolution: the unmodified input and the de-watermarked prediction.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # BUG FIX: the default used to be "WD_Net" (underscore), which matched
    # no branch and crashed with UnboundLocalError when the caller relied
    # on the default.
    img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
    org_shape = img.shape
    # Unlike `infer`, the "before" image is shown at full resolution.
    org_img = tf.cast(img, tf.uint8).numpy()

    # Shared 256x256 model input; branches differ only in normalization
    # and output range.
    small = tf.image.resize(img, [256, 256], method='area')
    if model == 'WD-Net':
        batch = tf.expand_dims(tf.cast(small, tf.float32) / 255., axis=0)
        # WD-Net returns a dict; key 'I' holds the restored image in [0, 1].
        rm_wt = WD_Net_new.predict(batch, verbose=0)['I'][0]
    elif model == 'MS-UNet':
        # MS-UNet works in [-1, 1] on both input and output.
        batch = tf.expand_dims((tf.cast(small, tf.float32) - 127.5) / 127.5, axis=0)
        rm_wt = (MS_UNet.predict(batch, verbose=0)[0] + 1) / 2
    elif model == 'Pix2Pix':
        batch = tf.expand_dims(tf.cast(small, tf.float32) / 255., axis=0)
        rm_wt = Pix2Pix.predict(batch, verbose=0)[0]
    else:
        # Fail loudly rather than falling through to UnboundLocalError.
        raise ValueError(
            f"Unknown model {model!r}; expected 'WD-Net', 'MS-UNet' or 'Pix2Pix'.")
    # Upscale the [0, 1] prediction to the original size and quantize.
    rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
    out_img = (rm_wt * 255).astype(np.uint8)
    return org_img, out_img
# ---- Gradio UI ----
# (Fixed: the original final line carried a stray trailing '|' — table
# extraction residue — which made the file a syntax error.)
# Sample images bundled with the Space, offered in the "Patch" tab dropdown.
data = os.listdir('./data')
data.sort()
# Models offered in both tabs; names must match the branches in infer/infer_v1.
model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']
# Tab 1 ("Document"): user uploads an arbitrary image.
demo = gr.Interface(
    fn=infer,
    inputs=[gr.Image(label="Choose an Image"), gr.Dropdown(model_list, label="Model")],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)
# Tab 2 ("Patch"): user picks one of the bundled samples by filename.
demo_v1 = gr.Interface(
    fn=infer_v1,
    inputs=[gr.Dropdown(data, label="Choose an Image"), gr.Dropdown(model_list, label="Model")],
    outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
)
tabbed_interface = gr.TabbedInterface([demo, demo_v1], ["Document", "Patch"], title="Watermark Removal")
tabbed_interface.launch()