import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Conv2DTranspose, Concatenate, Activation
from tensorflow.keras.models import Model

# Side length images are resized to before inference. Must be a multiple of
# 32 so every encoder skip connection lines up with the decoder feature map.
INPUT_SIZE = 256


def load_model():
    """Build a U-Net style image-to-image model on a frozen MobileNetV2 encoder.

    The encoder is MobileNetV2 (Keras default weights) without its classifier
    head. Five intermediate activations are tapped. The deepest one is the
    starting point of the decoder, and the other four are skip connections.
    The decoder runs four stages of ``Conv2DTranspose`` (2x upsample) ->
    concatenate the matching encoder activation -> ReLU. A final
    ``Conv2DTranspose`` + sigmoid then produces a 3-channel image at the
    input resolution.

    NOTE(review): no decoder weights are loaded in this function, so the
    decoder is randomly initialised. Output will not look like a drawing until
    trained weights are loaded (e.g. ``model.load_weights(...)``).

    Returns:
        A ``tf.keras.Model`` mapping (batch, H, W, 3) inputs scaled to [-1, 1]
        to (batch, H, W, 3) outputs in [0, 1]. H and W must be multiples of 32.
    """
    base_model = MobileNetV2(input_shape=(None, None, 3), include_top=False)

    # Encoder taps, shallow -> deep, with stride relative to the input size.
    layer_names = [
        'block_1_expand_relu',   # 1/2
        'block_3_expand_relu',   # 1/4
        'block_6_expand_relu',   # 1/8
        'block_13_expand_relu',  # 1/16
        'block_16_project',      # 1/32 (bottleneck)
    ]
    layers = [base_model.get_layer(name).output for name in layer_names]
    down_stack = Model(inputs=base_model.input, outputs=layers)
    down_stack.trainable = False

    inputs = tf.keras.layers.Input(shape=[None, None, 3])
    skips = down_stack(inputs)
    # Decoding starts from the deepest encoder feature map, not the raw image.
    x = skips[-1]

    # Each stage doubles the resolution and fuses the encoder activation that
    # has that same resolution (deepest skip first).
    for filters, skip in zip((512, 256, 128, 64), reversed(skips[:-1])):
        x = Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
        x = Concatenate()([x, skip])
        x = Activation('relu')(x)

    # Final 2x upsample back to the input resolution, squashed into [0, 1].
    x = Conv2DTranspose(3, 3, strides=2, padding='same')(x)
    outputs = Activation('sigmoid')(x)

    return Model(inputs=inputs, outputs=outputs)


model = load_model()


def sketch_image(img):
    """Run *img* through the model and return the result at the input's size.

    Args:
        img: Image as a numpy array from the Gradio ``Image`` component, which
            is RGB by default. Grayscale and RGBA arrays are converted to RGB.
            ``None`` (nothing submitted) is passed straight through.

    Returns:
        ``uint8`` array of shape (H, W, 3) with the input's height and width,
        or ``None`` when *img* is ``None``.
    """
    if img is None:
        return None

    # Gradio already supplies RGB, so no BGR swap is needed. Only the channel
    # count is normalised here.
    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # Capture the original size before any resizing/batching so the result
    # can be mapped back to it.
    height, width = img.shape[:2]

    # The MobileNetV2 encoder expects pixels scaled to [-1, 1]
    # (what mobilenet_v2.preprocess_input does).
    x = img.astype(np.float32) / 127.5 - 1.0
    x = cv2.resize(x, (INPUT_SIZE, INPUT_SIZE))
    output = model.predict(x[np.newaxis], verbose=0)[0]

    # cv2.resize takes (width, height).
    output = cv2.resize(output, (width, height))
    return np.clip(output * 255, 0, 255).astype(np.uint8)


title = "Picture to Drawing"
description = "Turn your pictures into beautiful drawings!"
# Gradio UI: the uploaded image is passed to sketch_image as a numpy array,
# and the returned uint8 array is displayed. gr.Image replaces the
# gr.inputs / gr.outputs namespaces, which were removed in Gradio 4.
inputs = gr.Image(label="Input Image")
outputs = gr.Image(label="Output Image")
examples = [['examples/1.jpg'], ['examples/2.jpg'], ['examples/3.jpg']]
demo = gr.Interface(
    sketch_image,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests or `gradio app.py` reload mode) does not block.
if __name__ == "__main__":
    demo.launch()