File size: 2,384 Bytes
1daae06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Conv2DTranspose, Concatenate, Activation
from tensorflow.keras.models import Model
def load_model():
    """Build a U-Net style image-to-image network on a MobileNetV2 encoder.

    The encoder is Keras' MobileNetV2 (default ImageNet weights), frozen.
    Five of its intermediate activations are tapped:
    - The deepest one is where the decoder starts.
    - The other four are concatenated back in as skip connections while
      four stride-2 transposed convolutions upsample it.

    A final stride-2 transposed convolution restores the input resolution,
    and a sigmoid maps the three output channels into [0, 1].

    NOTE(review): only the encoder has pretrained weights. The decoder
    layers created here are randomly initialised, and nothing in this
    function loads trained weights for them. Until such weights are loaded,
    the output is not a meaningful "drawing" — TODO load a checkpoint.

    Returns:
        An uncompiled ``tf.keras.Model``. It maps a batch of shape
        ``(batch, H, W, 3)`` with values in [0, 1] to an output of the same
        shape with values in [0, 1]. ``H`` and ``W`` should be multiples
        of 32 so the skip connections line up (``sketch_image`` uses 256x256).
    """
    base_model = MobileNetV2(input_shape=(None, None, 3), include_top=False)
    # Encoder taps, shallow -> deep; sizes are relative to the input
    # (e.g. a 256x256 input gives 128x128 at block_1_expand_relu).
    layer_names = [
        'block_1_expand_relu',   # 1/2 of input
        'block_3_expand_relu',   # 1/4
        'block_6_expand_relu',   # 1/8
        'block_13_expand_relu',  # 1/16
        'block_16_project',      # 1/32 (bottleneck)
    ]
    layers = [base_model.get_layer(name).output for name in layer_names]
    down_stack = Model(inputs=base_model.input, outputs=layers)
    down_stack.trainable = False

    inputs = tf.keras.layers.Input(shape=[None, None, 3])
    # MobileNetV2's ImageNet weights expect pixels in [-1, 1]; callers feed [0, 1].
    scaled = inputs * 2.0 - 1.0
    skips = down_stack(scaled)

    # Decode from the bottleneck feature map, not from the raw image.
    # Each stage is upsample -> concat the matching skip -> ReLU, walking the
    # remaining taps from deep to shallow.
    x = skips[-1]
    for filters, skip in zip((512, 256, 128, 64), reversed(skips[:-1])):
        x = Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
        x = Concatenate()([x, skip])
        x = Activation('relu')(x)

    # Final upsample back to full resolution with 3 channels squashed to [0, 1].
    x = Conv2DTranspose(3, 3, strides=2, padding='same')(x)
    outputs = Activation('sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)
# Build the network once at import time; sketch_image reuses this global
# for every request instead of rebuilding it.
model = load_model()
def sketch_image(img):
    """Run one picture through ``model`` and return the result at the picture's size.

    Args:
        img: Image array from the Gradio ``Image`` input. With Gradio's
            default settings this is an RGB ``uint8`` array of shape
            ``(H, W, 3)``. It is ``None`` when no image was submitted.

    Returns:
        A 3-channel ``uint8`` array of shape ``(H, W, 3)`` (same height and
        width as ``img``), or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        # Nothing was uploaded; let Gradio show an empty output.
        return None
    # Capture the original size *before* resizing. The old code read it back
    # from the resized (1, 256, 256, 3) batch, so it always returned 256x256.
    height, width = img.shape[:2]
    # No BGR->RGB conversion: Gradio already supplies RGB. That swap is only
    # needed for cv2.imread output; here it reversed the channel order.
    normalized = img.astype(np.float32) / 255.0
    resized = cv2.resize(normalized, (256, 256))
    batch = np.expand_dims(resized, axis=0)
    prediction = np.squeeze(model.predict(batch), axis=0)
    # cv2.resize takes dsize as (width, height).
    prediction = cv2.resize(prediction, (width, height))
    return np.clip(prediction * 255, 0, 255).astype('uint8')
# Gradio UI: one image in, one image out, with bundled example pictures.
title = "Picture to Drawing"
description = "Turn your pictures into beautiful drawings!"
# gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0.
# gr.Image serves as both the input and the output component.
inputs = gr.Image(label="Input Image")
outputs = gr.Image(label="Output Image")
examples = [['examples/1.jpg'], ['examples/2.jpg'], ['examples/3.jpg']]
# Keep a handle on the Interface so other tooling (e.g. `gradio` reload mode,
# which looks for `demo`) can find it; launching on import is unchanged.
demo = gr.Interface(
    sketch_image,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()
|