"""Gradio demo for arbitrary image stylization with a TF-Hub Magenta model.

The user supplies a style image and a content image. The content image is
re-rendered in the style of the style image by a single forward pass of the
model loaded from ``hub_handle``, and the result is served through a Gradio
interface.
"""
import functools

import gradio
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from matplotlib import gridspec
from PIL import Image

# Longest side (in pixels) that load_img scales images to by default.
MAX_DIM = 256
# Size the style image is resized to before being fed to the model; the hub
# handle's "-256" suffix suggests the model expects 256x256 style inputs.
STYLE_IMAGE_SIZE = (256, 256)


def load_img(img, max_dim=MAX_DIM):
    """Convert an image to a float32 batch tensor whose longest side is ``max_dim``.

    Args:
        img: image array; presumably (height, width, channels) uint8 as produced
            by ``np.array`` on a PIL image / Gradio upload.
        max_dim: target length of the longest spatial side (default 256, the
            value this function always used).

    Returns:
        A float32 tensor of shape (1, new_height, new_width, channels). Integer
        inputs are rescaled to [0, 1] by ``convert_image_dtype``.
    """
    img = tf.image.convert_image_dtype(img, tf.float32)
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)
    # Round rather than truncate so float error cannot leave the long side at
    # max_dim - 1.
    new_shape = tf.cast(tf.round(shape * scale), tf.int32)
    img = tf.image.resize(img, new_shape)
    return img[tf.newaxis, :]


def crop_center(image):
    """Return the largest centered square crop of a batched (NHWC) image tensor."""
    shape = image.shape
    new_shape = min(shape[1], shape[2])
    offset_y = max(shape[1] - shape[2], 0) // 2
    offset_x = max(shape[2] - shape[1], 0) // 2
    return tf.image.crop_to_bounding_box(
        image, offset_y, offset_x, new_shape, new_shape)


def load_image(img, image_size=(256, 256), preserve_aspect_ratio=True):
    """Load and preprocess an image; equivalent to :func:`load_img`.

    Not memoized: images arrive as numpy arrays or eager tensors, which are
    unhashable, so ``functools.lru_cache`` cannot key on them (it raises
    ``TypeError``).

    ``image_size`` and ``preserve_aspect_ratio`` are accepted for call-site
    compatibility but, as before, do not affect the result.
    """
    return load_img(img)


def show_n(images, titles=('',)):
    """Plot batch-1 image tensors side by side with matplotlib (debug helper).

    Args:
        images: sequence of batched images; ``image[0]`` is displayed and
            ``image.shape[1]`` sets each panel's relative width.
        titles: optional per-image titles; missing entries default to ''.
    """
    if not images:
        return
    n = len(images)
    image_sizes = [image.shape[1] for image in images]
    # Clamp to 1 inch so very small images still produce a drawable figure.
    w = max(1, (image_sizes[0] * 6) // 320)
    plt.figure(figsize=(w * n, w))
    gs = gridspec.GridSpec(1, n, width_ratios=image_sizes)
    for i, image in enumerate(images):
        plt.subplot(gs[i])
        plt.imshow(image[0], aspect='equal')
        plt.axis('off')
        plt.title(titles[i] if i < len(titles) else '')
    plt.show()


def _to_rgb_array(image):
    """Return ``image`` as a numpy array with exactly three (RGB) channels."""
    array = np.asarray(image)
    if array.ndim == 2:  # grayscale (H, W): replicate into 3 channels
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 1:  # (H, W, 1)
        array = np.repeat(array, 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:  # RGBA: drop alpha
        array = array[..., :3]
    return array


def load_content_style_img(style_image, content_image):
    """Prepare the style and content images for the stylization model.

    The content image is normalized to float32, then resized back to fit its
    original (height, width). The style image is center-cropped to a square,
    resized to ``STYLE_IMAGE_SIZE`` and lightly blurred.

    Returns:
        ``(style_image, content_image)`` as batched float32 tensors.

    Raises:
        ValueError: if either image is missing (e.g. nothing was uploaded).
    """
    if style_image is None or content_image is None:
        raise ValueError("Both a style image and a content image are required.")
    style_image = _to_rgb_array(style_image)
    content_image = _to_rgb_array(content_image)
    height, width = content_image.shape[:2]

    content_image = load_img(content_image)
    style_image = load_img(style_image)

    # tf.image.resize takes (height, width); passing (width, height) shrank the
    # output to fit the image's shorter side.
    content_image = tf.image.resize(
        content_image, (height, width), preserve_aspect_ratio=True)

    style_image = crop_center(style_image)
    style_image = tf.image.resize(
        style_image, STYLE_IMAGE_SIZE, preserve_aspect_ratio=True)
    # 3x3 average pool with stride 1 and SAME padding: a light blur that keeps
    # the spatial size unchanged.
    style_image = tf.nn.avg_pool(
        style_image, ksize=[3, 3], strides=[1, 1], padding='SAME')
    return style_image, content_image


hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
# Loaded once at import time; every request reuses this module.
hub_module = hub.load(hub_handle)


def tensor_to_image(tensor):
    """Convert a float image tensor with values nominally in [0, 1] to a PIL image.

    Values are clipped to [0, 255] after scaling so out-of-range outputs
    saturate instead of wrapping around in the uint8 cast. A leading batch
    dimension of size 1 is removed.
    """
    tensor = np.clip(np.asarray(tensor) * 255, 0, 255).astype(np.uint8)
    if np.ndim(tensor) > 3:
        assert tensor.shape[0] == 1
        tensor = tensor[0]
    return Image.fromarray(tensor)


# Unused module-level placeholder; train() returns its result instead.
stylized_image = 0


def train(style, content):
    """Stylize ``content`` with ``style`` and return the result as a PIL image.

    Despite the name, nothing is trained: this runs the loaded hub module once
    on the prepared (content, style) pair.
    """
    style_image, content_image = load_content_style_img(style, content)
    outputs = hub_module(tf.constant(content_image), tf.constant(style_image))
    return tensor_to_image(outputs[0])


# First input is the style image, second the content image (matches train()).
# NOTE(review): `.style()` is deprecated in later Gradio 3.x releases and
# removed in 4.x; on newer Gradio pass `height=600` to `gradio.Image` instead.
gr = gradio.Interface(
    fn=train,
    inputs=['image', 'image'],
    outputs=[gradio.Image(label='output').style(height=600)],
)
gr.launch(share=False)