Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import numpy as np | |
| import functools | |
| from matplotlib import gridspec | |
| import gradio | |
def load_img(img, max_dim=256):
    """Convert an image to a float32 batch whose longer side is *max_dim* px.

    Args:
        img: Image array/tensor whose last axis is treated as channels, e.g.
            ``(height, width, channels)``. Integer inputs are rescaled to
            ``[0, 1]`` by ``tf.image.convert_image_dtype``.
        max_dim: Target length, in pixels, of the longer spatial side. The
            image is scaled up or down to reach it; both sides use the same
            scale factor.

    Returns:
        A float32 tensor with a leading batch axis of size 1.
    """
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Spatial dimensions only: drop the trailing channel axis.
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)
    new_shape = tf.cast(shape * scale, tf.int32)
    img = tf.image.resize(img, new_shape)
    # Add the leading batch axis expected by the downstream model.
    return img[tf.newaxis, :]
def crop_center(image):
    """Returns a cropped square image.

    Takes the largest centred square from *image*. The spatial axes are read
    from the end of the shape (``-3`` = height, ``-2`` = width, ``-1`` =
    channels), so both a batch ``(batch, height, width, channels)`` and a
    single ``(height, width, channels)`` image are accepted, as numpy arrays
    or TensorFlow tensors.

    Args:
        image: Array/tensor with at least three dimensions, channels last.

    Returns:
        The centre crop, whose height and width both equal
        ``min(height, width)``; leading and channel axes are unchanged.
    """
    height, width = image.shape[-3], image.shape[-2]
    side = min(height, width)
    # The longer axis is trimmed equally from both ends; with an odd
    # difference the extra pixel comes off the far end.
    offset_y = (height - side) // 2
    offset_x = (width - side) // 2
    return image[..., offset_y:offset_y + side, offset_x:offset_x + side, :]
def load_image(img, image_size=(256, 256), preserve_aspect_ratio=True):
    """Loads and preprocesses images.

    Converts *img* to float32, resizes it according to *image_size* and
    *preserve_aspect_ratio*, and adds a leading batch axis.

    Args:
        img: Image array/tensor, channels last. Integer inputs are rescaled
            to ``[0, 1]`` by ``tf.image.convert_image_dtype``.
        image_size: ``(height, width)`` target. When *preserve_aspect_ratio*
            is true the image is scaled to fit inside this box (for the
            default ``(256, 256)`` its longer side becomes 256 px); otherwise
            it is resized to exactly this size.
        preserve_aspect_ratio: Whether to keep the input's aspect ratio.

    Returns:
        A float32 tensor with a leading batch axis of size 1.
    """
    img = tf.image.convert_image_dtype(img, tf.float32)
    # With preserve_aspect_ratio=True, tf.image.resize fits the image inside
    # image_size, which scales the longer side to the box's limiting edge.
    img = tf.image.resize(
        img, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
    return img[tf.newaxis, :]
def show_n(images, titles=('',)):
    """Display batched images side by side in a single matplotlib row.

    Args:
        images: Sequence of batched images; ``images[i][0]`` is what gets
            drawn, and ``images[i].shape[1]`` sets the relative column width.
            Nothing is drawn when the sequence is empty.
        titles: Per-image titles; images without a matching entry get ''.
    """
    n = len(images)
    if n == 0:
        return
    # NOTE(review): for NHWC batches shape[1] is the height, not the width;
    # width_ratios may have been meant to use shape[2] — confirm.
    image_sizes = [image.shape[1] for image in images]
    # Figure size scales with the first image; clamp to >= 1 so images
    # narrower than 54 px don't produce a zero-size figure.
    w = max(1, (image_sizes[0] * 6) // 320)
    plt.figure(figsize=(w * n, w))
    gs = gridspec.GridSpec(1, n, width_ratios=image_sizes)
    for i, image in enumerate(images):
        plt.subplot(gs[i])
        plt.imshow(image[0], aspect='equal')
        plt.axis('off')
        plt.title(titles[i] if i < len(titles) else '')
    plt.show()
def load_content_style_img(style_image, content_image):
    """Prepare the style and content images for the stylization model.

    Args:
        style_image: Style picture, anything ``np.array`` accepts, with
            ``(height, width, channels)`` layout.
        content_image: Content picture, same conventions.

    Returns:
        ``(style_image, content_image)`` as float32 batched tensors:

        - The style image is centre-cropped to a square, resized to 256x256
          and smoothed with a 3x3 average pool.
        - The content image is resized to fit within its original height
          and width.
    """
    style_image = np.array(style_image)
    content_image = np.array(content_image)
    # Remember the original content resolution so it can be restored below.
    height, width = content_image.shape[0], content_image.shape[1]
    content_image = load_img(content_image)
    style_image = load_img(style_image)
    # tf.image.resize expects (height, width); swapping them would shrink
    # non-square content images instead of restoring their size.
    content_image = tf.image.resize(
        content_image, (height, width), preserve_aspect_ratio=True)
    style_image = crop_center(style_image)
    style_image = tf.image.resize(
        style_image, (256, 256), preserve_aspect_ratio=True)
    # 3x3 average pool with stride 1 and SAME padding: a mild box blur that
    # keeps the spatial size unchanged.
    style_image = tf.nn.avg_pool(
        style_image, ksize=[3, 3], strides=[1, 1], padding='SAME')
    return style_image, content_image
| # style_image,content_image=load_content_style_img(style,content) | |
| # display([content_image, style_image]) | |
| #show_n([content_image, style_image], ['Content image', 'Style image']) | |
# TF Hub image-stylization model, loaded once when this module is imported;
# train() below calls it as hub_module(content, style).
hub_handle = (
    'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
)
hub_module = hub.load(hub_handle)
def tensor_to_image(tensor):
    """Convert a float image with values in ``[0, 1]`` to a PIL image.

    Args:
        tensor: Array-like (numpy array or eager TF tensor), either a single
            ``(height, width, channels)`` image or a batch holding exactly
            one such image. Values are expected in ``[0, 1]``.

    Returns:
        A ``PIL.Image.Image`` built from the 8-bit pixel array.

    Raises:
        ValueError: If a batched input holds more than one image.
    """
    pixels = np.asarray(tensor) * 255
    # Saturate before the uint8 cast: out-of-range floats would otherwise
    # wrap around (or be undefined) instead of clamping to 0/255.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    if pixels.ndim > 3:
        if pixels.shape[0] != 1:
            raise ValueError(
                f"expected a batch of exactly one image, got shape {pixels.shape}")
        pixels = pixels[0]
    return Image.fromarray(pixels)
# Module-level placeholder; train() binds its own local of the same name
# and never touches this global.
stylized_image = 0


def train(style, content):
    """Stylize *content* with the look of *style* and return a PIL image.

    Args:
        style: Style picture as delivered by the first Gradio input.
        content: Content picture as delivered by the second Gradio input.

    Returns:
        The first element of the hub model's output, converted by
        ``tensor_to_image``.
    """
    style_tensor, content_tensor = load_content_style_img(style, content)
    # The model is called with the content image first, then the style image.
    result = hub_module(tf.constant(content_tensor), tf.constant(style_tensor))[0]
    return tensor_to_image(result)
# Web UI: the two inputs reach train() positionally as (style, content), so
# label them explicitly — two unlabeled "image" boxes are indistinguishable.
gr = gradio.Interface(
    fn=train,
    inputs=[
        gradio.Image(label='Style image'),
        gradio.Image(label='Content image'),
    ],
    outputs=[gradio.Image(label='output').style(height=600)],
)
gr.launch(share=False)