|
|
import gradio as gr |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.applications import MobileNetV2 |
|
|
from tensorflow.keras.layers import Conv2DTranspose, Concatenate, Activation |
|
|
from tensorflow.keras.models import Model |
|
|
|
|
|
def load_model():
    """Build a U-Net style image-to-image network on a frozen MobileNetV2 encoder.

    The encoder exposes five intermediate MobileNetV2 activations. The deepest
    one seeds the decoder; the other four act as skip connections. Each decoder
    stage doubles the spatial resolution with a transposed convolution,
    concatenates the matching skip tensor and applies ReLU. A last transposed
    convolution followed by a sigmoid restores the input resolution and emits
    3 channels in the range [0, 1].

    Returns:
        A ``tf.keras.Model`` mapping a ``(batch, H, W, 3)`` tensor to a
        ``(batch, H, W, 3)`` tensor. H and W should be multiples of 32 so the
        upsampled decoder tensors line up with the encoder skip tensors.

    Note:
        No decoder checkpoint is loaded here: the decoder layers are freshly
        initialised, only the encoder relies on MobileNetV2's default
        ``weights`` argument.
    """
    encoder = MobileNetV2(input_shape=(None, None, 3), include_top=False)

    # Encoder taps ordered from shallowest (highest resolution) to deepest.
    tap_names = [
        'block_1_expand_relu',
        'block_3_expand_relu',
        'block_6_expand_relu',
        'block_13_expand_relu',
        'block_16_project',
    ]
    tap_outputs = [encoder.get_layer(name).output for name in tap_names]

    down_stack = Model(inputs=encoder.input, outputs=tap_outputs)
    down_stack.trainable = False

    # One filter count per decoder stage; each stage pairs with one skip.
    decoder_filters = (512, 256, 128, 64)

    inputs = tf.keras.layers.Input(shape=[None, None, 3])

    # The deepest tap is the decoder's starting point; the remaining taps are
    # the skip connections. Starting from `inputs` instead would make the
    # upsampled tensors incompatible with every skip.
    *skips, x = down_stack(inputs)

    # Walk the skips from deepest to shallowest so that, after each 2x
    # upsampling, `x` has the same spatial size as the skip it is joined with.
    for filters, skip in zip(decoder_filters, reversed(skips)):
        x = Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
        x = Concatenate()([x, skip])
        x = Activation('relu')(x)

    # Final 2x upsampling back to the input resolution, squashed to [0, 1].
    x = Conv2DTranspose(3, 3, strides=2, padding='same')(x)
    x = Activation('sigmoid')(x)

    return Model(inputs=inputs, outputs=x)
|
|
|
|
|
# Build the network once at import time so that every call to sketch_image()
# reuses the same model instance instead of rebuilding it per request.
model = load_model()
|
|
|
|
|
def sketch_image(img):
    """Run the module-level ``model`` on an image and return a uint8 result.

    Args:
        img: Image array handed over by the Gradio image input; assumed to be
            an ``(H, W, 3)`` uint8 array -- TODO confirm against the Gradio
            version in use. ``None`` (nothing submitted) is passed through.

    Returns:
        The model output scaled to 0-255 as a uint8 array, resized back to the
        height and width of the input image, or ``None`` when ``img`` is
        ``None``.
    """
    if img is None:
        return None

    # Capture the caller's resolution now: `img` is rebound below to the
    # 256x256 network input, so reading its shape later would always yield
    # 256x256 and the result would never be scaled back.
    orig_height, orig_width = img.shape[:2]

    # NOTE(review): Gradio normally delivers RGB arrays, so this BGR->RGB
    # conversion effectively swaps the channels to BGR. Kept as-is because the
    # channel order the model expects cannot be confirmed from this file.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img / 255.0
    img = cv2.resize(img, (256, 256))
    batch = np.expand_dims(img, axis=0)

    output = model.predict(batch)
    output = np.squeeze(output, axis=0)

    # cv2.resize takes the target size as (width, height).
    output = cv2.resize(output, (orig_width, orig_height))
    output = np.clip(output * 255, 0, 255).astype('uint8')
    return output
|
|
|
|
|
# Text shown at the top of the Gradio page.
title = "Picture to Drawing"
description = "Turn your pictures into beautiful drawings!"

# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy component
# namespaces; newer Gradio releases expose the component as `gr.Image`.
# Left unchanged because the installed Gradio version is not visible here.
inputs = gr.inputs.Image(label="Input Image")
outputs = gr.outputs.Image(label="Output Image")

# Sample images offered in the UI; paths are relative to the working directory.
examples = [['examples/1.jpg'], ['examples/2.jpg'], ['examples/3.jpg']]

# Wire sketch_image() to the image input/output and start the web server.
gr.Interface(
    fn=sketch_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
).launch()
|
|
|